使用 TensorFlow Lite Model Maker 进行目标检测

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本

在此 CoLab 笔记本中,您将学习如何使用 TensorFlow Lite Model Maker 库来训练能够在移动设备上检测图像中的沙拉的自定义目标检测模型。

Model Maker 库使用迁移学习来简化使用自定义数据集训练 TensorFlow Lite 模型的过程。使用您自己的自定义数据集重新训练 TensorFlow Lite 模型可以减少所需的训练数据量,并将缩短训练时间。

您将使用公开可用的 Salad 数据集,该数据集创建自 Open Images Dataset V4

该数据集中的每个图像都包含标记为以下其中一类的对象:

  • 烘焙食品
  • 奶酪
  • 沙拉
  • 海鲜
  • 番茄

该数据集包含指定每个对象所在位置的边界框以及对象的标签。

以下是数据集中的示例图像:


先决条件

安装所需的软件包

首先安装所需的包,包括来自 GitHub 仓库的 Model Maker 软件包和将用于评估的 pycocotools 库。

sudo apt -y install libportaudio2
pip install -q --use-deprecated=legacy-resolver tflite-model-maker
pip install -q pycocotools
pip install -q opencv-python-headless==4.1.2.30
pip uninstall -y tensorflow && pip install -q tensorflow==2.8.0

导入所需的软件包。

import numpy as np
import os

from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

准备数据集

在这里,您将使用与 AutoML 快速入门相同的数据集。

Salads 数据集可从以下地址获得:gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv.

其中包含 175 个用于训练的图像,25 个用于验证的图像,以及 25 个用于测试的图像。数据集有五个类:SaladSeafoodTomatoBaked goodsCheese


数据集以 CSV 格式提供:

TRAINING,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Salad,0.0,0.0954,,,0.977,0.957,,
VALIDATION,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Seafood,0.0154,0.1538,,,1.0,0.802,,
TEST,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Tomato,0.0,0.655,,,0.231,0.839,,
  • 每一行对应于一个定位在较大图像中的对象,每个对象被专门指定为测试、训练或验证数据。在本笔记本的后面阶段,您将了解关于这么做的意义的更多信息。
  • 这里包含的三行表示同一图像中的三个不同对象,可从以下地址获得:gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg
  • 每一行都有不同的标签:SaladSeafoodTomato 等。
  • 使用左上角和右下角顶点为每个图像指定边界框。

以下是这三行的可视化效果:


如果您想了解有关如何准备您自己的 CSV 文件以及创建有效数据集的最低要求的更多信息,请参阅准备您的训练数据指南了解更多详细信息。

如果您是 Google Cloud 的新用户,您可能想知道 gs:// 网址是什么意思。它们是存储在 Google Cloud Storage (GCS) 上的文件的网址。如果您在 GCS 上公开您的文件或验证您的客户端,Model Maker 可以像读取本地文件一样读取这些文件。

然而,您不需要将图片保存在 Google Cloud 上就可以使用 Model Maker。您可以在 CSV 文件中使用本地路径,Model Maker 将正常工作。

快速入门

训练目标检测模型有六个步骤:

第 1 步:选择目标检测模型架构。

本教程使用 EfficientDet-Lite0 模型。EfficientDet-Lite[0-4] 是一系列移动/物联网友好的目标检测模型,派生自 EfficientDet 架构。

以下是每种 EfficientDet-Lite 模型之间的性能对比。

模型架构 大小 (MB)* 延迟 (ms)** 平均精度***
EfficientDet-Lite0 4.4 37 25.69%
EfficientDet-Lite1 5.8 49 30.55%
EfficientDet-Lite2 7.2 69 33.97%
EfficientDet-Lite3 11.4 116 37.70%
EfficientDet-Lite4 19.9 260 41.96%

* 整数量化模型的大小。
** 延迟在使用 4 个 CPU 线程的 Pixel 4 上测得。
*** 平均精度是 COCO 2017 验证数据集上的 mAP(平均精度均值)。

spec = model_spec.get('efficientdet_lite0')

第 2 步:加载数据集。

Model Maker 将接收 CSV 格式的输入数据。使用 object_detector.DataLoader.from_csv 方法加载数据集,并将其分割为训练、验证和测试图像。

  • 训练图像:这些图像用于训练目标检测模型识别沙拉成分。
  • 验证图像:这些是模型在训练过程中没有见过的图像。您将使用它们来决定何时应停止训练,以避免过拟合
  • 测试图像:这些图像用于评估最终模型的性能。

您可以直接从 Google Cloud Storage 加载 CSV 文件,但不需要在 Google Cloud 上保留您的图像来使用 Model Maker。您可以在计算机上指定一个本地 CSV 文件,Model Maker 即可正常工作。

train_data, validation_data, test_data = object_detector.DataLoader.from_csv('gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv')

第 3 步:用训练数据训练 TensorFlow 模型。

  • EfficientDet-Lite0 模型默认使用 epochs = 50,这意味着它将对训练数据集进行 50 次遍历。您可以在训练期间查看验证准确率,并提前停止,以避免过拟合。
  • 在此处设置 batch_size = 8,这样您将看到,遍历训练数据集中的 175 个图像需要 21 个步骤。
  • 设置 train_whole_model=True 可以对整个模型进行微调,而不仅仅是训练头层来提高准确率。代价是训练模型可能需要更长的时间。
model = object_detector.create(train_data, model_spec=spec, batch_size=8, train_whole_model=True, validation_data=validation_data)
Epoch 1/50
21/21 [==============================] - 41s 432ms/step - det_loss: 1.7675 - cls_loss: 1.1338 - box_loss: 0.0127 - reg_l2_loss: 0.0635 - loss: 1.8310 - learning_rate: 0.0090 - gradient_norm: 0.7692 - val_det_loss: 1.6300 - val_cls_loss: 1.0870 - val_box_loss: 0.0109 - val_reg_l2_loss: 0.0635 - val_loss: 1.6936
Epoch 2/50
21/21 [==============================] - 6s 284ms/step - det_loss: 1.6326 - cls_loss: 1.0827 - box_loss: 0.0110 - reg_l2_loss: 0.0635 - loss: 1.6961 - learning_rate: 0.0100 - gradient_norm: 0.9392 - val_det_loss: 1.4099 - val_cls_loss: 0.9169 - val_box_loss: 0.0099 - val_reg_l2_loss: 0.0635 - val_loss: 1.4735
Epoch 3/50
21/21 [==============================] - 6s 288ms/step - det_loss: 1.4586 - cls_loss: 0.9606 - box_loss: 0.0100 - reg_l2_loss: 0.0635 - loss: 1.5221 - learning_rate: 0.0099 - gradient_norm: 1.8510 - val_det_loss: 1.4634 - val_cls_loss: 1.0030 - val_box_loss: 0.0092 - val_reg_l2_loss: 0.0635 - val_loss: 1.5269
Epoch 4/50
21/21 [==============================] - 6s 313ms/step - det_loss: 1.2882 - cls_loss: 0.8228 - box_loss: 0.0093 - reg_l2_loss: 0.0636 - loss: 1.3517 - learning_rate: 0.0099 - gradient_norm: 2.2037 - val_det_loss: 1.4783 - val_cls_loss: 1.0472 - val_box_loss: 0.0086 - val_reg_l2_loss: 0.0636 - val_loss: 1.5418
Epoch 5/50
21/21 [==============================] - 12s 573ms/step - det_loss: 1.1467 - cls_loss: 0.7332 - box_loss: 0.0083 - reg_l2_loss: 0.0636 - loss: 1.2103 - learning_rate: 0.0098 - gradient_norm: 1.8869 - val_det_loss: 1.1540 - val_cls_loss: 0.7317 - val_box_loss: 0.0084 - val_reg_l2_loss: 0.0636 - val_loss: 1.2176
Epoch 6/50
21/21 [==============================] - 6s 295ms/step - det_loss: 1.0630 - cls_loss: 0.6670 - box_loss: 0.0079 - reg_l2_loss: 0.0636 - loss: 1.1266 - learning_rate: 0.0097 - gradient_norm: 1.8586 - val_det_loss: 1.0742 - val_cls_loss: 0.6978 - val_box_loss: 0.0075 - val_reg_l2_loss: 0.0636 - val_loss: 1.1378
Epoch 7/50
21/21 [==============================] - 7s 318ms/step - det_loss: 1.0248 - cls_loss: 0.6464 - box_loss: 0.0076 - reg_l2_loss: 0.0636 - loss: 1.0884 - learning_rate: 0.0096 - gradient_norm: 1.7937 - val_det_loss: 0.9156 - val_cls_loss: 0.5667 - val_box_loss: 0.0070 - val_reg_l2_loss: 0.0636 - val_loss: 0.9792
Epoch 8/50
21/21 [==============================] - 6s 288ms/step - det_loss: 0.9946 - cls_loss: 0.6278 - box_loss: 0.0073 - reg_l2_loss: 0.0636 - loss: 1.0582 - learning_rate: 0.0094 - gradient_norm: 1.8340 - val_det_loss: 0.8883 - val_cls_loss: 0.5422 - val_box_loss: 0.0069 - val_reg_l2_loss: 0.0636 - val_loss: 0.9519
Epoch 9/50
21/21 [==============================] - 6s 283ms/step - det_loss: 0.9602 - cls_loss: 0.6179 - box_loss: 0.0068 - reg_l2_loss: 0.0636 - loss: 1.0238 - learning_rate: 0.0093 - gradient_norm: 1.9670 - val_det_loss: 0.8432 - val_cls_loss: 0.5169 - val_box_loss: 0.0065 - val_reg_l2_loss: 0.0636 - val_loss: 0.9068
Epoch 10/50
21/21 [==============================] - 7s 366ms/step - det_loss: 0.9316 - cls_loss: 0.5941 - box_loss: 0.0067 - reg_l2_loss: 0.0636 - loss: 0.9952 - learning_rate: 0.0091 - gradient_norm: 2.0556 - val_det_loss: 0.8252 - val_cls_loss: 0.5355 - val_box_loss: 0.0058 - val_reg_l2_loss: 0.0636 - val_loss: 0.8889
Epoch 11/50
21/21 [==============================] - 6s 314ms/step - det_loss: 0.9019 - cls_loss: 0.5770 - box_loss: 0.0065 - reg_l2_loss: 0.0636 - loss: 0.9656 - learning_rate: 0.0089 - gradient_norm: 1.9285 - val_det_loss: 0.7875 - val_cls_loss: 0.5036 - val_box_loss: 0.0057 - val_reg_l2_loss: 0.0636 - val_loss: 0.8511
Epoch 12/50
21/21 [==============================] - 6s 283ms/step - det_loss: 0.8818 - cls_loss: 0.5675 - box_loss: 0.0063 - reg_l2_loss: 0.0636 - loss: 0.9455 - learning_rate: 0.0087 - gradient_norm: 1.9203 - val_det_loss: 0.8129 - val_cls_loss: 0.5313 - val_box_loss: 0.0056 - val_reg_l2_loss: 0.0636 - val_loss: 0.8766
Epoch 13/50
21/21 [==============================] - 6s 289ms/step - det_loss: 0.8463 - cls_loss: 0.5471 - box_loss: 0.0060 - reg_l2_loss: 0.0636 - loss: 0.9099 - learning_rate: 0.0085 - gradient_norm: 1.9731 - val_det_loss: 0.8389 - val_cls_loss: 0.5602 - val_box_loss: 0.0056 - val_reg_l2_loss: 0.0636 - val_loss: 0.9026
Epoch 14/50
21/21 [==============================] - 6s 282ms/step - det_loss: 0.8532 - cls_loss: 0.5572 - box_loss: 0.0059 - reg_l2_loss: 0.0637 - loss: 0.9169 - learning_rate: 0.0082 - gradient_norm: 2.2388 - val_det_loss: 0.7846 - val_cls_loss: 0.5136 - val_box_loss: 0.0054 - val_reg_l2_loss: 0.0637 - val_loss: 0.8483
Epoch 15/50
21/21 [==============================] - 8s 389ms/step - det_loss: 0.8212 - cls_loss: 0.5365 - box_loss: 0.0057 - reg_l2_loss: 0.0637 - loss: 0.8849 - learning_rate: 0.0080 - gradient_norm: 2.1374 - val_det_loss: 0.7583 - val_cls_loss: 0.4988 - val_box_loss: 0.0052 - val_reg_l2_loss: 0.0637 - val_loss: 0.8219
Epoch 16/50
21/21 [==============================] - 6s 289ms/step - det_loss: 0.8107 - cls_loss: 0.5229 - box_loss: 0.0058 - reg_l2_loss: 0.0637 - loss: 0.8744 - learning_rate: 0.0077 - gradient_norm: 2.1675 - val_det_loss: 0.8215 - val_cls_loss: 0.5593 - val_box_loss: 0.0052 - val_reg_l2_loss: 0.0637 - val_loss: 0.8852
Epoch 17/50
21/21 [==============================] - 6s 286ms/step - det_loss: 0.7978 - cls_loss: 0.5132 - box_loss: 0.0057 - reg_l2_loss: 0.0637 - loss: 0.8615 - learning_rate: 0.0075 - gradient_norm: 2.3357 - val_det_loss: 0.8640 - val_cls_loss: 0.5992 - val_box_loss: 0.0053 - val_reg_l2_loss: 0.0637 - val_loss: 0.9277
Epoch 18/50
21/21 [==============================] - 6s 293ms/step - det_loss: 0.7642 - cls_loss: 0.5042 - box_loss: 0.0052 - reg_l2_loss: 0.0637 - loss: 0.8278 - learning_rate: 0.0072 - gradient_norm: 2.3080 - val_det_loss: 0.7775 - val_cls_loss: 0.5312 - val_box_loss: 0.0049 - val_reg_l2_loss: 0.0637 - val_loss: 0.8412
Epoch 19/50
21/21 [==============================] - 6s 284ms/step - det_loss: 0.7664 - cls_loss: 0.4978 - box_loss: 0.0054 - reg_l2_loss: 0.0637 - loss: 0.8301 - learning_rate: 0.0069 - gradient_norm: 2.1915 - val_det_loss: 0.7483 - val_cls_loss: 0.5196 - val_box_loss: 0.0046 - val_reg_l2_loss: 0.0637 - val_loss: 0.8120
Epoch 20/50
21/21 [==============================] - 8s 388ms/step - det_loss: 0.7418 - cls_loss: 0.4831 - box_loss: 0.0052 - reg_l2_loss: 0.0637 - loss: 0.8055 - learning_rate: 0.0066 - gradient_norm: 2.2534 - val_det_loss: 0.7673 - val_cls_loss: 0.5250 - val_box_loss: 0.0048 - val_reg_l2_loss: 0.0637 - val_loss: 0.8310
Epoch 21/50
21/21 [==============================] - 6s 280ms/step - det_loss: 0.8007 - cls_loss: 0.5087 - box_loss: 0.0058 - reg_l2_loss: 0.0637 - loss: 0.8644 - learning_rate: 0.0063 - gradient_norm: 2.5656 - val_det_loss: 0.7869 - val_cls_loss: 0.5317 - val_box_loss: 0.0051 - val_reg_l2_loss: 0.0637 - val_loss: 0.8506
Epoch 22/50
21/21 [==============================] - 6s 295ms/step - det_loss: 0.7374 - cls_loss: 0.4861 - box_loss: 0.0050 - reg_l2_loss: 0.0637 - loss: 0.8011 - learning_rate: 0.0060 - gradient_norm: 2.2847 - val_det_loss: 0.8170 - val_cls_loss: 0.5612 - val_box_loss: 0.0051 - val_reg_l2_loss: 0.0637 - val_loss: 0.8807
Epoch 23/50
21/21 [==============================] - 6s 283ms/step - det_loss: 0.7384 - cls_loss: 0.4833 - box_loss: 0.0051 - reg_l2_loss: 0.0637 - loss: 0.8021 - learning_rate: 0.0056 - gradient_norm: 2.2279 - val_det_loss: 0.8068 - val_cls_loss: 0.5482 - val_box_loss: 0.0052 - val_reg_l2_loss: 0.0637 - val_loss: 0.8705
Epoch 24/50
21/21 [==============================] - 7s 322ms/step - det_loss: 0.7370 - cls_loss: 0.4840 - box_loss: 0.0051 - reg_l2_loss: 0.0637 - loss: 0.8007 - learning_rate: 0.0053 - gradient_norm: 2.3582 - val_det_loss: 0.8889 - val_cls_loss: 0.6238 - val_box_loss: 0.0053 - val_reg_l2_loss: 0.0637 - val_loss: 0.9526
Epoch 25/50
21/21 [==============================] - 7s 364ms/step - det_loss: 0.7079 - cls_loss: 0.4719 - box_loss: 0.0047 - reg_l2_loss: 0.0637 - loss: 0.7716 - learning_rate: 0.0050 - gradient_norm: 2.6073 - val_det_loss: 0.8224 - val_cls_loss: 0.5445 - val_box_loss: 0.0056 - val_reg_l2_loss: 0.0637 - val_loss: 0.8861
Epoch 26/50
21/21 [==============================] - 6s 288ms/step - det_loss: 0.7276 - cls_loss: 0.4669 - box_loss: 0.0052 - reg_l2_loss: 0.0637 - loss: 0.7913 - learning_rate: 0.0047 - gradient_norm: 2.4797 - val_det_loss: 0.8051 - val_cls_loss: 0.5411 - val_box_loss: 0.0053 - val_reg_l2_loss: 0.0637 - val_loss: 0.8688
Epoch 27/50
21/21 [==============================] - 6s 290ms/step - det_loss: 0.7037 - cls_loss: 0.4558 - box_loss: 0.0050 - reg_l2_loss: 0.0637 - loss: 0.7674 - learning_rate: 0.0044 - gradient_norm: 2.3431 - val_det_loss: 0.8115 - val_cls_loss: 0.5432 - val_box_loss: 0.0054 - val_reg_l2_loss: 0.0637 - val_loss: 0.8752
Epoch 28/50
21/21 [==============================] - 6s 308ms/step - det_loss: 0.7116 - cls_loss: 0.4590 - box_loss: 0.0051 - reg_l2_loss: 0.0637 - loss: 0.7753 - learning_rate: 0.0040 - gradient_norm: 2.7340 - val_det_loss: 0.8266 - val_cls_loss: 0.5705 - val_box_loss: 0.0051 - val_reg_l2_loss: 0.0637 - val_loss: 0.8904
Epoch 29/50
21/21 [==============================] - 6s 278ms/step - det_loss: 0.6887 - cls_loss: 0.4471 - box_loss: 0.0048 - reg_l2_loss: 0.0637 - loss: 0.7524 - learning_rate: 0.0037 - gradient_norm: 2.4508 - val_det_loss: 0.7780 - val_cls_loss: 0.5375 - val_box_loss: 0.0048 - val_reg_l2_loss: 0.0637 - val_loss: 0.8417
Epoch 30/50
21/21 [==============================] - 7s 359ms/step - det_loss: 0.6970 - cls_loss: 0.4540 - box_loss: 0.0049 - reg_l2_loss: 0.0637 - loss: 0.7607 - learning_rate: 0.0034 - gradient_norm: 2.4654 - val_det_loss: 0.7527 - val_cls_loss: 0.5348 - val_box_loss: 0.0044 - val_reg_l2_loss: 0.0637 - val_loss: 0.8164
Epoch 31/50
21/21 [==============================] - 6s 285ms/step - det_loss: 0.6573 - cls_loss: 0.4318 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7210 - learning_rate: 0.0031 - gradient_norm: 2.3340 - val_det_loss: 0.7645 - val_cls_loss: 0.5354 - val_box_loss: 0.0046 - val_reg_l2_loss: 0.0637 - val_loss: 0.8282
Epoch 32/50
21/21 [==============================] - 6s 279ms/step - det_loss: 0.6636 - cls_loss: 0.4338 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7273 - learning_rate: 0.0028 - gradient_norm: 2.3318 - val_det_loss: 0.7956 - val_cls_loss: 0.5594 - val_box_loss: 0.0047 - val_reg_l2_loss: 0.0637 - val_loss: 0.8593
Epoch 33/50
21/21 [==============================] - 6s 316ms/step - det_loss: 0.6770 - cls_loss: 0.4470 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7407 - learning_rate: 0.0025 - gradient_norm: 2.5163 - val_det_loss: 0.7844 - val_cls_loss: 0.5614 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8482
Epoch 34/50
21/21 [==============================] - 6s 287ms/step - det_loss: 0.6811 - cls_loss: 0.4467 - box_loss: 0.0047 - reg_l2_loss: 0.0637 - loss: 0.7449 - learning_rate: 0.0023 - gradient_norm: 2.5778 - val_det_loss: 0.7602 - val_cls_loss: 0.5361 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8240
Epoch 35/50
21/21 [==============================] - 8s 370ms/step - det_loss: 0.6745 - cls_loss: 0.4449 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7382 - learning_rate: 0.0020 - gradient_norm: 2.5151 - val_det_loss: 0.7660 - val_cls_loss: 0.5389 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8297
Epoch 36/50
21/21 [==============================] - 6s 296ms/step - det_loss: 0.6652 - cls_loss: 0.4305 - box_loss: 0.0047 - reg_l2_loss: 0.0637 - loss: 0.7289 - learning_rate: 0.0018 - gradient_norm: 2.3834 - val_det_loss: 0.7527 - val_cls_loss: 0.5288 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8164
Epoch 37/50
21/21 [==============================] - 7s 319ms/step - det_loss: 0.6420 - cls_loss: 0.4219 - box_loss: 0.0044 - reg_l2_loss: 0.0637 - loss: 0.7058 - learning_rate: 0.0015 - gradient_norm: 2.2842 - val_det_loss: 0.7558 - val_cls_loss: 0.5281 - val_box_loss: 0.0046 - val_reg_l2_loss: 0.0637 - val_loss: 0.8195
Epoch 38/50
21/21 [==============================] - 6s 288ms/step - det_loss: 0.6515 - cls_loss: 0.4270 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7153 - learning_rate: 0.0013 - gradient_norm: 2.4729 - val_det_loss: 0.7496 - val_cls_loss: 0.5206 - val_box_loss: 0.0046 - val_reg_l2_loss: 0.0637 - val_loss: 0.8134
Epoch 39/50
21/21 [==============================] - 6s 288ms/step - det_loss: 0.6499 - cls_loss: 0.4260 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7136 - learning_rate: 0.0011 - gradient_norm: 2.3206 - val_det_loss: 0.7411 - val_cls_loss: 0.5182 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8049
Epoch 40/50
21/21 [==============================] - 8s 378ms/step - det_loss: 0.6462 - cls_loss: 0.4175 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7099 - learning_rate: 9.0029e-04 - gradient_norm: 2.4006 - val_det_loss: 0.7520 - val_cls_loss: 0.5253 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8157
Epoch 41/50
21/21 [==============================] - 6s 280ms/step - det_loss: 0.6542 - cls_loss: 0.4267 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7179 - learning_rate: 7.2543e-04 - gradient_norm: 2.3908 - val_det_loss: 0.7583 - val_cls_loss: 0.5314 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8220
Epoch 42/50
21/21 [==============================] - 6s 311ms/step - det_loss: 0.6444 - cls_loss: 0.4262 - box_loss: 0.0044 - reg_l2_loss: 0.0637 - loss: 0.7081 - learning_rate: 5.6814e-04 - gradient_norm: 2.4907 - val_det_loss: 0.7401 - val_cls_loss: 0.5170 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8039
Epoch 43/50
21/21 [==============================] - 6s 278ms/step - det_loss: 0.6464 - cls_loss: 0.4179 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7101 - learning_rate: 4.2906e-04 - gradient_norm: 2.3817 - val_det_loss: 0.7376 - val_cls_loss: 0.5162 - val_box_loss: 0.0044 - val_reg_l2_loss: 0.0637 - val_loss: 0.8013
Epoch 44/50
21/21 [==============================] - 6s 280ms/step - det_loss: 0.6526 - cls_loss: 0.4254 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7163 - learning_rate: 3.0876e-04 - gradient_norm: 2.4726 - val_det_loss: 0.7359 - val_cls_loss: 0.5144 - val_box_loss: 0.0044 - val_reg_l2_loss: 0.0637 - val_loss: 0.7996
Epoch 45/50
21/21 [==============================] - 7s 366ms/step - det_loss: 0.6283 - cls_loss: 0.4161 - box_loss: 0.0042 - reg_l2_loss: 0.0637 - loss: 0.6920 - learning_rate: 2.0774e-04 - gradient_norm: 2.3249 - val_det_loss: 0.7380 - val_cls_loss: 0.5157 - val_box_loss: 0.0044 - val_reg_l2_loss: 0.0637 - val_loss: 0.8017
Epoch 46/50
21/21 [==============================] - 6s 315ms/step - det_loss: 0.6479 - cls_loss: 0.4240 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7116 - learning_rate: 1.2641e-04 - gradient_norm: 2.5665 - val_det_loss: 0.7422 - val_cls_loss: 0.5183 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8059
Epoch 47/50
21/21 [==============================] - 6s 279ms/step - det_loss: 0.6418 - cls_loss: 0.4172 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7055 - learning_rate: 6.5107e-05 - gradient_norm: 2.6628 - val_det_loss: 0.7433 - val_cls_loss: 0.5190 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8070
Epoch 48/50
21/21 [==============================] - 6s 284ms/step - det_loss: 0.6513 - cls_loss: 0.4284 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7150 - learning_rate: 2.4083e-05 - gradient_norm: 2.4614 - val_det_loss: 0.7444 - val_cls_loss: 0.5196 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8081
Epoch 49/50
21/21 [==============================] - 6s 290ms/step - det_loss: 0.6410 - cls_loss: 0.4188 - box_loss: 0.0044 - reg_l2_loss: 0.0637 - loss: 0.7048 - learning_rate: 3.5074e-06 - gradient_norm: 2.4491 - val_det_loss: 0.7436 - val_cls_loss: 0.5192 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8074
Epoch 50/50
21/21 [==============================] - 8s 369ms/step - det_loss: 0.6415 - cls_loss: 0.4182 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7052 - learning_rate: 3.4629e-06 - gradient_norm: 2.3632 - val_det_loss: 0.7439 - val_cls_loss: 0.5192 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8076

第 4 步:在测试数据上评估模型。

在使用训练数据集中的图像训练目标检测模型之后,使用测试数据集中的剩余 25 个图像来评估该模型在它以前从未见过的新数据上的性能。

由于默认批次大小为 64,因此只需一个步骤即可遍历测试数据集中的 25 个图像。

评估指标与 COCO 相同。

model.evaluate(test_data)
1/1 [==============================] - 6s 6s/step
{'AP': 0.1919611,
 'AP50': 0.34319812,
 'AP75': 0.1862729,
 'APs': -1.0,
 'APm': 0.3241761,
 'APl': 0.1923377,
 'ARmax1': 0.16314393,
 'ARmax10': 0.3077774,
 'ARmax100': 0.35298797,
 'ARs': -1.0,
 'ARm': 0.55,
 'ARl': 0.35037425,
 'AP_/Baked Goods': 0.008910891,
 'AP_/Salad': 0.5489583,
 'AP_/Cheese': 0.15109558,
 'AP_/Seafood': 0.05949282,
 'AP_/Tomato': 0.19134793}

第 5 步:导出为 TensorFlow Lite 模型

通过指定要将量化模型导出到的文件夹,将训练的目标检测模型导出为 TensorFlow Lite 格式。默认的训练后量化技术是全整数量化。

model.export(export_dir='.')
2022-08-31 02:23:43.606436: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
2022-08-31 02:24:08.865819: W tensorflow/core/common_runtime/graph_constructor.cc:803] Node 'resample_p7/PartitionedCall' has 1 outputs but the _output_shapes attribute specifies shapes for 3 outputs. Output shapes may be inaccurate.
2022-08-31 02:24:16.164308: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format.
2022-08-31 02:24:16.164355: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency.
fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 0

第 6 步:评估 TensorFlow Lite 模型。

在导出为 TFLite 时,有几个因素可能会影响模型准确率:

  • 量化有助于将模型大小缩小为原来的四分一直,但代价是准确率会略微下降。
  • 原始 TensorFlow 模型使用每个类的非极大值抑制 (NMS) 进行后处理,而 TFLite 模型使用全局 NMS,速度快得多,但准确率较低。Keras 最多输出 100 个检测,而 Tflite 最多输出 25 个检测。

因此,您必须评估导出的 TFLite 模型,并将其准确率与原始 TensorFlow 模型进行比较。

model.evaluate_tflite('model.tflite', test_data)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
25/25 [==============================] - 67s 3s/step
{'AP': 0.18182021,
 'AP50': 0.31197006,
 'AP75': 0.19087325,
 'APs': -1.0,
 'APm': 0.2809451,
 'APl': 0.18239675,
 'ARmax1': 0.14099042,
 'ARmax10': 0.25025177,
 'ARmax100': 0.26311275,
 'ARs': -1.0,
 'ARm': 0.34166667,
 'ARl': 0.26235557,
 'AP_/Baked Goods': 0.0,
 'AP_/Salad': 0.5375744,
 'AP_/Cheese': 0.14260961,
 'AP_/Seafood': 0.05841584,
 'AP_/Tomato': 0.17050114}

您可以使用 Colab 的左侧边栏下载 TensorFlow Lite 模型文件。右键点击 model.tflite 文件,然后选择 Download 将其下载到本地计算机。

可以使用 TensorFlow Lite Task LibraryObjectDetector API 将此模型集成到 Android 或 iOS 应用中。

有关如何在工作应用中使用模型的更多详细信息,请参阅 TFLite 目标检测示例应用

注:Android Studio Model Binding 目前还不支持目标检测,请使用 TensorFlow Lite Task Library。

(可选)在您的图像上测试 TFLite 模型

您可以使用互联网上的图像测试训练后的 TFLite 模型。

  • 将下面的 INPUT_IMAGE_URL 替换为所需的输入图像。
  • 调整 DETECTION_THRESHOLD 以更改模型的灵敏度。较低的阈值意味着模型将拾取更多对象,但也会有更多的错误检测。与此同时,更高的阈值意味着该模型将只拾取它确信检测到的对象。

尽管目前需要一些样板代码才能在 Python 中运行模型,但将模型集成到移动应用中只需要几行代码。

Load the trained TFLite model and define some visualization functions

Run object detection and show the detection results

/tmpfs/tmp/ipykernel_81842/1212634245.py:10: DeprecationWarning: ANTIALIAS is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.LANCZOS instead.
  im.thumbnail((512, 512), Image.ANTIALIAS)

png

(可选)针对 Edge TPU 编译

现在您已经有了量化的 EfficientDet Lite 模型,可以编译并部署到 Coral EdgeTPU

第 1 步:安装 EdgeTPU 编译器

 curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -

 echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list

 sudo apt-get update

 sudo apt-get install edgetpu-compiler
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2537  100  2537    0     0   130k      0 --:--:-- --:--:-- --:--:--  130k
OK
deb https://packages.cloud.google.com/apt coral-edgetpu-stable main
Hit:1 http://us-central1.gce.archive.ubuntu.com/ubuntu focal InRelease
Hit:2 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates InRelease
Get:3 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-backports InRelease [108 kB]
Hit:4 http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  InRelease
Get:6 https://packages.cloud.google.com/apt coral-edgetpu-stable InRelease [6722 B]
Hit:7 https://nvidia.github.io/libnvidia-container/stable/ubuntu18.04/amd64  InRelease
Hit:8 https://download.docker.com/linux/ubuntu focal InRelease
Get:9 https://nvidia.github.io/nvidia-container-runtime/stable/ubuntu18.04/amd64  InRelease [1481 B]
Get:10 https://nvidia.github.io/nvidia-docker/ubuntu18.04/amd64  InRelease [1474 B]
Hit:11 http://security.ubuntu.com/ubuntu focal-security InRelease
Ign:12 http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  InRelease
Hit:13 http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  Release
Hit:5 https://apt.llvm.org/focal llvm-toolchain-focal-14 InRelease
Hit:14 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease
Ign:15 https://packages.cloud.google.com/apt coral-edgetpu-stable/main amd64 Packages
Hit:16 http://ppa.launchpad.net/longsleep/golang-backports/ubuntu focal InRelease
Hit:17 http://ppa.launchpad.net/openjdk-r/ppa/ubuntu focal InRelease
Get:15 https://packages.cloud.google.com/apt coral-edgetpu-stable/main amd64 Packages [2317 B]
Fetched 120 kB in 1s (82.9 kB/s)




The following packages were automatically installed and are no longer required:
  libatasmart4 libblockdev-fs2 libblockdev-loop2 libblockdev-part-err2
  libblockdev-part2 libblockdev-swap2 libblockdev-utils2 libblockdev2 libnuma1
  libparted-fs-resize0
Use 'sudo apt autoremove' to remove them.
The following NEW packages will be installed:
  edgetpu-compiler
0 upgraded, 1 newly installed, 0 to remove and 258 not upgraded.
Need to get 7913 kB of archives.
After this operation, 31.2 MB of additional disk space will be used.
Get:1 https://packages.cloud.google.com/apt coral-edgetpu-stable/main amd64 edgetpu-compiler amd64 16.0 [7913 kB]
Fetched 7913 kB in 0s (30.0 MB/s)
Selecting previously unselected package edgetpu-compiler.
(Reading database ... 139823 files and directories currently installed.)
Preparing to unpack .../edgetpu-compiler_16.0_amd64.deb ...
Unpacking edgetpu-compiler (16.0) ...
Setting up edgetpu-compiler (16.0) ...
Processing triggers for libc-bin (2.31-0ubuntu9.7) ...

第 2 步:选择 Edge TPU 数量,然后编译

EdgeTPU 有 8MB 的 SRAM 用于缓存模型参数(更多信息)。这意味着对于大于 8MB 的模型,为了传递模型参数,推断时间将增加。避免这种情况的一种方式是模型流水线 - 将模型拆分成可以使用专用 EdgeTPU 的段。这可以显著改善延迟。

下表可用作要使用的 Edge TPU 数量的参考 - 由于中间张量无法放入片上内存,较大的模型将无法使用单个 TPU 编译。

模型架构 最低 TPU 数 建议 TPU 数
EfficientDet-Lite0 1 1
EfficientDet-Lite1 1 1
EfficientDet-Lite2 1 2
EfficientDet-Lite3 2 2
EfficientDet-Lite4 2 3

Edge TPU Compiler version 16.0.384591198
Started a compilation timeout timer of 180 seconds.

Model compiled successfully in 4407 ms.

Input model: model.tflite
Input size: 4.24MiB
Output model: model_edgetpu.tflite
Output size: 5.61MiB
On-chip memory used for caching model parameters: 4.24MiB
On-chip memory remaining for caching model parameters: 3.27MiB
Off-chip memory used for streaming uncached model parameters: 0.00B
Number of Edge TPU subgraphs: 1
Total number of operations: 267
Operation log: model_edgetpu.log

Model successfully compiled but not all operations are supported by the Edge TPU. A percentage of the model will instead run on the CPU, which is slower. If possible, consider updating your model to use only operations supported by the Edge TPU. For details, visit g.co/coral/model-reqs.
Number of operations that will run on Edge TPU: 264
Number of operations that will run on CPU: 3
See the operation log file for individual operation details.
Compilation child process completed within timeout period.
Compilation succeeded!

第 3 步:下载并运行模型

经过编译后,现在可以在 EdgeTPU 上运行模型以进行目标检测。首先,使用 Colab 的左侧边栏下载编译后的 TensorFlow Lite 模型文件。右键点击 model_edgetpu.tflite 文件,然后选择 Download 将其下载到本地计算机。

现在,您可以用您喜欢的方式运行模型。检测示例包括:

高级用法

本部分介绍高级用法主题,如调整模型和训练超参数。

加载数据集

加载您自己的数据

您可以上传您自己的数据集以完成本教程。请使用 Colab 的左侧边栏上传您的数据集。

上传文件

如果您不想将数据集上传到云端,也可以按照指南在本地运行库。

使用不同的数据格式加载数据

Model Maker 库还支持 object_detector.DataLoader.from_pascal_voc 方法来加载 PASCAL VOC 格式的数据。makesense.aiLabelImg 工具可以注解图像并将注解保存为 PASCAL VOC 数据格式的 XML 文件:

object_detector.DataLoader.from_pascal_voc(image_dir, annotations_dir, label_map={1: "person", 2: "notperson"})

自定义 EfficientDet 模型超参数

可以调整的模型和训练流水线参数包括:

  • model_dir:模型检查点文件的保存位置。如果未设置,将使用临时目录。
  • steps_per_execution:每个训练执行的步骤数。
  • moving_average_decay浮点。用于维护训练参数的移动平均值的衰减。
  • var_freeze_expr:映射待冻结变量的前缀名称的正则表达式,表示在训练期间保持不变。更具体地说,在代码库中使用 re.match(var_freeze_expr, variable_name) 来映射要冻结的变量。
  • tflite_max_detections:整数,默认为 25。TFLite 模型中的最大输出检测数。
  • strategy:指定使用哪种分布策略的字符串。可接受的值为 'tpu'、'gpus'、None。'tpu' 是指使用 TPUStrategy。'gpus' 是指为多 GPU 使用 MirroredStrategy。如果为 None,则使用 OneDeviceStrategy 的 TF 默认值。
  • tpu:用于训练的 Cloud TPU。这应该是创建 Cloud TPU 时使用的名称,或者是 grpc://ip.address.of.tpu:8470 网址。
  • use_xla:即使策略不是 TPU,也使用 XLA。如果策略是 TPU,则始终使用 XLA,并且此标志无效。
  • profile:启用配置文件模式。
  • debug:启用调试模式。

其他可以调整的参数如 hparams_config.py 中所示。

例如,您可以设置 var_freeze_expr='efficientnet',这将冻结名称前缀为 efficientnet 的变量(默认为 '(efficientnet|fpn_cells|resample_p6)')。这允许模型冻结不可训练的变量,并在训练过程中保持它们的值不变。

spec = model_spec.get('efficientdet_lite0')
spec.config.var_freeze_expr = 'efficientnet'

更改模型架构

您可以通过更改 model_spec 来更改模型架构。例如,将 model_spec 更改为 EfficientDet-Lite4 模型。

spec = model_spec.get('efficientdet_lite4')

调整训练超参数

create 函数是 Model Maker 库用于创建模型的驱动函数。model_spec 参数定义模型规范。目前支持 object_detector.EfficientDetSpec 类。create 函数包括以下步骤:

  1. 根据 model_spec 创建用于目标检测的模型。
  2. 训练模型。默认周期和默认批次大小由 model_spec 对象中的 epochsbatch_size 变量设置。您还可以调整训练超参数,如影响模型准确率的 epochsbatch_size。例如,
  • epochs:整数,默认为 50。更多周期可以获得更好的准确率,但可以会导致过拟合。
  • batch_size:整数,默认为 64。一个训练步骤中要使用的样本数。
  • train_whole_model:布尔值,默认为 False。如果为 true,则训练整个模型。否则,只训练不匹配 var_freeze_expr 的层。

例如,您可以使用较少的周期进行训练,并且只使用头层。您可以增加周期数以获得更好的效果。

model = object_detector.create(train_data, model_spec=spec, epochs=10, validation_data=validation_data)

导出为不同格式

导出格式可以是以下列表中的一个或多个:

默认情况下,它仅导出包含模型元数据的 TensorFlow Lite 模型文件,以便以后在设备端机器学习应用中使用。标签文件嵌入在元数据中。

在许多设备端机器学习应用中,模型大小是一个重要因素。因此,建议您量化模型以使其更小并可能加快运行速度。对于 EfficientDet-Lite 模型,默认使用全整数量化来量化模型。请参阅训练后量化了解详细信息。

model.export(export_dir='.')

您还可以选择导出与模型相关的其他文件,以便更好地进行检查。例如,按如下方式同时导出保存的模型和标签文件:

model.export(export_dir='.', export_format=[ExportFormat.SAVED_MODEL, ExportFormat.LABEL])

在 TensorFlow Lite 模型上自定义训练后量化

训练后量化是一种转换技术,可以缩减模型大小并缩短推断延迟,同时改善 CPU 和硬件加速器推断速度,且几乎不会降低模型准确率。因此,它被广泛用于优化模型。

Model Maker 库在导出模型时会应用默认的训练后量化技术。如果您想自定义训练后量化,Model Maker 也支持使用 QuantizationConfig 的多个训练后量化选项。我们以 float16 量化为例。首先,定义量化配置。

config = QuantizationConfig.for_float16()

然后,我们使用此配置导出 TensorFlow Lite 模型。

model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)

阅读更多

您可以阅读我们的目标检测示例以了解技术细节。如需了解更多信息,请参阅: