ML Community Day is November 9! Join us for updates from TensorFlow, JAX, and more Learn more

使用元数据生成模型接口

开发者可以使用 TensorFlow Lite 元数据生成封装容器代码,以实现在 Android 上的集成。对于大多数开发者来说,Android Studio 机器学习模型绑定的图形界面最易于使用。如果您需要更多的自定义或正在使用命令行工具,也可以使用 TensorFlow Lite Codegen

使用 Android Studio 机器学习模型绑定

对于使用元数据增强的 TensorFlow Lite 模型,开发者可以使用 Android Studio 机器学习模型绑定来自动配置项目设置,并基于模型元数据生成封装容器类。封装容器代码消除了直接与 ByteBuffer 交互的需要。相反,开发者可以使用 BitmapRect 等类型化对象与 TensorFlow Lite 模型进行交互。

注:需要 Android Studio 4.1 或以上版本

在 Android Studio 中导入 TensorFlow Lite 模型

  1. 右键点击要使用 TFLite 模型的模块,或者点击 File,然后依次点击 New>Other>TensorFlow Lite Model Right-click menus to access the TensorFlow Lite import functionality

  2. 选择 TFLite 文件的位置。请注意,该工具将使用机器学习绑定代您配置模块的依赖关系,且所有依赖关系会自动插入 Android 模块的 build.gradle 文件。

    可选:如果要使用 GPU 加速,请选择导入 TensorFlow GPU 的第二个复选框。Import dialog for TFLite model

  3. 点击 Finish

  4. 导入成功后,会出现以下界面。要开始使用该模型,请选择 Kotlin 或 Java,复制并粘贴 Sample Code 部分的代码。在 Android Studio 中双击 ml 目录下的 TFLite 模型,可以返回此界面。Model details page in Android Studio

加速模型推断

机器学习模型绑定为开发者提供了一种通过使用委托和线程数量来加速代码的方式。

注:TensorFlow Lite 解释器必须在其运行时的同一个线程上创建。不然,TfLiteGpuDelegate Invoke: GpuDelegate 必须在初始化它的同一线程上运行。否则可能会发生错误。

步骤 1. 检查模块 build.gradle 文件是否包含以下依赖关系:

    dependencies {
        ...
        // TFLite GPU delegate 2.3.0 or above is required.
        implementation 'org.tensorflow:tensorflow-lite-gpu:2.3.0'
    }

步骤 2. 检测设备上运行的 GPU 是否兼容 TensorFlow GPU 委托,如不兼容,则使用多个 CPU 线程运行模型:

Kotlin

    import org.tensorflow.lite.gpu.CompatibilityList
    import org.tensorflow.lite.gpu.GpuDelegate
</div>
<pre data-md-type="block_code" data-md-language=""><code>GL_CODE_13</code>

用 TensorFlow Lite 代码生成器生成模型接口 {:#codegen}

注:TensorFlow Lite 封装容器代码生成器目前只支持 Android。

对于使用元数据增强的 TensorFlow Lite 模型,开发者可以使用 TensorFlow Lite Android 封装容器代码生成器来创建特定平台的封装容器代码。封装容器代码消除了直接与 ByteBuffer 交互的需要。相反,开发者可以使用 BitmapRect 等类型化对象与 TensorFlow Lite 模型进行交互。

代码生成器是否有用取决于 TensorFlow Lite 模型的元数据条目是否完整。请参考 metadata_schema.fbs 中相关字段下的 <Codegen usage> 部分,查看代码生成器工具如何解析每个字段。

生成封装容器代码

您需要在终端中安装以下工具:

pip install tflite-support

完成后,可以使用以下句法来使用代码生成器:

tflite_codegen --model=./model_with_metadata/mobilenet_v1_0.75_160_quantized.tflite \
    --package_name=org.tensorflow.lite.classify \
    --model_class_name=MyClassifierModel \
    --destination=./classify_wrapper

生成的代码将位于目标目录中。如果您使用的是 Google Colab 或其他远程环境,将结果压缩成 zip 归档并将其下载到您的 Android Studio 项目中可能会更加容易:

# Zip up the generated code
!zip -r classify_wrapper.zip classify_wrapper/

# Download the archive
from google.colab import files
files.download('classify_wrapper.zip')

使用生成的代码

步骤1:导入生成的代码

如有必要,将生成的代码解压缩到目录结构中。假定生成的代码的根目录为 SRC_ROOT

打开要使用 TensorFlow lite 模型的 Android Studio 项目,然后通过以下步骤导入生成的模块:File -> New -> Import Module -> 选择 SRC_ROOT

使用上面的示例,导入的目录和模块将称为 classify_wrapper

步骤 2:更新应用的 build.gradle 文件

在将使用生成的库模块的应用模块中:

在 android 部分下,添加以下内容:

aaptOptions {
   noCompress "tflite"
}

在 android 部分添加以下内容:

implementation project(":classify_wrapper")

步骤 3:使用模型

// 1. Initialize the model
MyClassifierModel myImageClassifier = null;

try {
    myImageClassifier = new MyClassifierModel(this);
} catch (IOException io){
    // Error reading the model
}

if(null != myImageClassifier) {

    // 2. Set the input with a Bitmap called inputBitmap
    MyClassifierModel.Inputs inputs = myImageClassifier.createInputs();
    inputs.loadImage(inputBitmap));

    // 3. Run the model
    MyClassifierModel.Outputs outputs = myImageClassifier.run(inputs);

    // 4. Retrieve the result
    Map<String, Float> labeledProbability = outputs.getProbability();
}

加速模型推断

生成的代码为开发者提供了一种通过使用委托和线程数来加速代码的方式。这些可以在初始化模型对象时设置,因为它需要三个参数:

  • Context:Android 活动或服务的上下文
  • (可选)Device:TFLite 加速委托,例如 GPUDelegate 或 NNAPIDelegate
  • (可选) numThreads:用于运行模型的线程数(默认为 1)。

例如,要使用 NNAPI 委托和最多三个线程,您可以像下面这样初始化模型:

try {
    myImageClassifier = new MyClassifierModel(this, Model.Device.NNAPI, 3);
} catch (IOException io){
    // Error reading the model
}

问题排查

如果您遇到
'java.io.FileNotFoundException: This file can not be opened as a file descriptor; it is probably compressed' 错误,请在将使用库模块的应用模块的 android 部分插入以下各行:

aaptOptions {
   noCompress "tflite"
}