In this Colab notebook, you'll learn how to use the TensorFlow Lite Model Maker library to train a custom object detection model capable of detecting salads within images on a mobile device.
The Model Maker library uses transfer learning to simplify the process of training a TensorFlow Lite model using a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data required and will shorten the training time.
You'll use the publicly available Salads dataset, which was created from the Open Images Dataset V4.
Each image in the dataset contains objects labeled as one of the following classes:
- Baked goods
- Cheese
- Salad
- Seafood
- Tomato
The dataset contains the bounding boxes specifying where each object is located, together with the object's label.
Here is an example image from the dataset:
Prerequisites
Install the required packages
Start by installing the required packages, including the Model Maker package from the GitHub repo and the pycocotools library you'll use for evaluation.
sudo apt -y install libportaudio2
pip install -q --use-deprecated=legacy-resolver tflite-model-maker
pip install -q pycocotools
pip install -q opencv-python-headless==4.1.2.30
pip uninstall -y tensorflow && pip install -q tensorflow==2.8.0
Import the required packages.
import numpy as np
import os
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)
Prepare the dataset
Here you'll use the same dataset as the AutoML quickstart.
The Salads dataset is available at: gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv.
It contains 175 images for training, 25 images for validation, and 25 images for testing. The dataset has five classes: Salad, Seafood, Tomato, Baked goods, Cheese.
The dataset is provided in CSV format:
TRAINING,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Salad,0.0,0.0954,,,0.977,0.957,,
VALIDATION,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Seafood,0.0154,0.1538,,,1.0,0.802,,
TEST,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Tomato,0.0,0.655,,,0.231,0.839,,
- Each row corresponds to an object localized inside a larger image, with each object specifically designated as test, train, or validation data. You'll learn more about what that means in a later stage of this notebook.
- The three lines included here indicate three distinct objects located inside the same image, available at gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg.
- Each row has a different label: Salad, Seafood, Tomato, etc.
- Bounding boxes are specified for each image using the top-left and bottom-right vertices.
Here is a visualization of these three lines:
If you want to know more about how to prepare your own CSV file and the minimum requirements for creating a valid dataset, see the Preparing your training data guide for more details.
If you are new to Google Cloud, you may wonder what the gs:// URLs mean. They are URLs of files stored on Google Cloud Storage (GCS). If you make your files on GCS public or authenticate your client, Model Maker can read those files similarly to your local files.
However, you don't need to keep your images on Google Cloud to use Model Maker. You can use a local path in your CSV file and Model Maker will just work.
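For instance, here is a minimal sketch of loading a local copy of the data. The paths are placeholders, and the images_dir keyword, used to resolve relative image paths in the CSV, is assumed from recent Model Maker releases:
# A minimal sketch, assuming a local CSV whose image paths are relative
# to a local images directory. Both paths are hypothetical placeholders.
local_train, local_val, local_test = object_detector.DataLoader.from_csv(
    'data/salads_ml_use.csv',
    images_dir='data/images'
)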
Quickstart
There are six steps to training an object detection model:
Step 1. Choose an object detection model architecture.
This tutorial uses the EfficientDet-Lite0 model. EfficientDet-Lite[0-4] are a family of mobile/IoT-friendly object detection models derived from the EfficientDet architecture.
Here is the performance of the EfficientDet-Lite models compared against each other.
Model architecture | Size (MB)* | Latency (ms)** | Average Precision*** |
---|---|---|---|
EfficientDet-Lite0 | 4.4 | 37 | 25.69% |
EfficientDet-Lite1 | 5.8 | 49 | 30.55% |
EfficientDet-Lite2 | 7.2 | 69 | 33.97% |
EfficientDet-Lite3 | 11.4 | 116 | 37.70% |
EfficientDet-Lite4 | 19.9 | 260 | 41.96% |
* Size of the integer quantized models.
** Latency measured on Pixel 4 using 4 threads on CPU.
*** Average Precision is the mAP (mean Average Precision) on the COCO 2017 validation dataset.
spec = model_spec.get('efficientdet_lite0')
Step 2. Load the dataset.
Model Maker takes input data in the CSV format. Use the object_detector.DataLoader.from_csv method to load the dataset and split it into the training, validation, and test images.
- Training images: These images are used to train the object detection model to recognize salad ingredients.
- Validation images: These are images that the model didn't see during the training process. You'll use them to decide when to stop training, to avoid overfitting.
- Test images: These images are used to evaluate the final model performance.
You can load the CSV file directly from Google Cloud Storage, but you don't need to keep your images on Google Cloud to use Model Maker. You can specify a local CSV file on your computer, and Model Maker will work just fine.
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv')
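As a quick sanity check, you can confirm the split sizes. The size property is an assumption carried over from other Model Maker tutorials:
# Optional sanity check; `size` is assumed to be exposed by the DataLoader.
print('Training images:', train_data.size)
print('Validation images:', validation_data.size)
print('Test images:', test_data.size)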
Step 3. Train the TensorFlow model with the training data.
- The EfficientDet-Lite0 model uses epochs = 50 by default, which means it will go through the training dataset 50 times. You can look at the validation accuracy during training and stop early to avoid overfitting.
- Set batch_size = 8 here so you will see that it takes 21 steps to go through the 175 images in the training dataset.
- Set train_whole_model=True to fine-tune the whole model instead of just training the head layer, to improve accuracy. The trade-off is that it may take longer to train the model.
model = object_detector.create(train_data, model_spec=spec, batch_size=8, train_whole_model=True, validation_data=validation_data)
Epoch 1/50 21/21 [==============================] - 41s 432ms/step - det_loss: 1.7675 - cls_loss: 1.1338 - box_loss: 0.0127 - reg_l2_loss: 0.0635 - loss: 1.8310 - learning_rate: 0.0090 - gradient_norm: 0.7692 - val_det_loss: 1.6300 - val_cls_loss: 1.0870 - val_box_loss: 0.0109 - val_reg_l2_loss: 0.0635 - val_loss: 1.6936 Epoch 2/50 21/21 [==============================] - 6s 284ms/step - det_loss: 1.6326 - cls_loss: 1.0827 - box_loss: 0.0110 - reg_l2_loss: 0.0635 - loss: 1.6961 - learning_rate: 0.0100 - gradient_norm: 0.9392 - val_det_loss: 1.4099 - val_cls_loss: 0.9169 - val_box_loss: 0.0099 - val_reg_l2_loss: 0.0635 - val_loss: 1.4735 Epoch 3/50 21/21 [==============================] - 6s 288ms/step - det_loss: 1.4586 - cls_loss: 0.9606 - box_loss: 0.0100 - reg_l2_loss: 0.0635 - loss: 1.5221 - learning_rate: 0.0099 - gradient_norm: 1.8510 - val_det_loss: 1.4634 - val_cls_loss: 1.0030 - val_box_loss: 0.0092 - val_reg_l2_loss: 0.0635 - val_loss: 1.5269 Epoch 4/50 21/21 [==============================] - 6s 313ms/step - det_loss: 1.2882 - cls_loss: 0.8228 - box_loss: 0.0093 - reg_l2_loss: 0.0636 - loss: 1.3517 - learning_rate: 0.0099 - gradient_norm: 2.2037 - val_det_loss: 1.4783 - val_cls_loss: 1.0472 - val_box_loss: 0.0086 - val_reg_l2_loss: 0.0636 - val_loss: 1.5418 Epoch 5/50 21/21 [==============================] - 12s 573ms/step - det_loss: 1.1467 - cls_loss: 0.7332 - box_loss: 0.0083 - reg_l2_loss: 0.0636 - loss: 1.2103 - learning_rate: 0.0098 - gradient_norm: 1.8869 - val_det_loss: 1.1540 - val_cls_loss: 0.7317 - val_box_loss: 0.0084 - val_reg_l2_loss: 0.0636 - val_loss: 1.2176 Epoch 6/50 21/21 [==============================] - 6s 295ms/step - det_loss: 1.0630 - cls_loss: 0.6670 - box_loss: 0.0079 - reg_l2_loss: 0.0636 - loss: 1.1266 - learning_rate: 0.0097 - gradient_norm: 1.8586 - val_det_loss: 1.0742 - val_cls_loss: 0.6978 - val_box_loss: 0.0075 - val_reg_l2_loss: 0.0636 - val_loss: 1.1378 Epoch 7/50 21/21 [==============================] - 7s 318ms/step - det_loss: 1.0248 - cls_loss: 0.6464 - box_loss: 0.0076 - reg_l2_loss: 0.0636 - loss: 1.0884 - learning_rate: 0.0096 - gradient_norm: 1.7937 - val_det_loss: 0.9156 - val_cls_loss: 0.5667 - val_box_loss: 0.0070 - val_reg_l2_loss: 0.0636 - val_loss: 0.9792 Epoch 8/50 21/21 [==============================] - 6s 288ms/step - det_loss: 0.9946 - cls_loss: 0.6278 - box_loss: 0.0073 - reg_l2_loss: 0.0636 - loss: 1.0582 - learning_rate: 0.0094 - gradient_norm: 1.8340 - val_det_loss: 0.8883 - val_cls_loss: 0.5422 - val_box_loss: 0.0069 - val_reg_l2_loss: 0.0636 - val_loss: 0.9519 Epoch 9/50 21/21 [==============================] - 6s 283ms/step - det_loss: 0.9602 - cls_loss: 0.6179 - box_loss: 0.0068 - reg_l2_loss: 0.0636 - loss: 1.0238 - learning_rate: 0.0093 - gradient_norm: 1.9670 - val_det_loss: 0.8432 - val_cls_loss: 0.5169 - val_box_loss: 0.0065 - val_reg_l2_loss: 0.0636 - val_loss: 0.9068 Epoch 10/50 21/21 [==============================] - 7s 366ms/step - det_loss: 0.9316 - cls_loss: 0.5941 - box_loss: 0.0067 - reg_l2_loss: 0.0636 - loss: 0.9952 - learning_rate: 0.0091 - gradient_norm: 2.0556 - val_det_loss: 0.8252 - val_cls_loss: 0.5355 - val_box_loss: 0.0058 - val_reg_l2_loss: 0.0636 - val_loss: 0.8889 Epoch 11/50 21/21 [==============================] - 6s 314ms/step - det_loss: 0.9019 - cls_loss: 0.5770 - box_loss: 0.0065 - reg_l2_loss: 0.0636 - loss: 0.9656 - learning_rate: 0.0089 - gradient_norm: 1.9285 - val_det_loss: 0.7875 - val_cls_loss: 0.5036 - val_box_loss: 0.0057 - val_reg_l2_loss: 0.0636 - val_loss: 0.8511 Epoch 
12/50 21/21 [==============================] - 6s 283ms/step - det_loss: 0.8818 - cls_loss: 0.5675 - box_loss: 0.0063 - reg_l2_loss: 0.0636 - loss: 0.9455 - learning_rate: 0.0087 - gradient_norm: 1.9203 - val_det_loss: 0.8129 - val_cls_loss: 0.5313 - val_box_loss: 0.0056 - val_reg_l2_loss: 0.0636 - val_loss: 0.8766 Epoch 13/50 21/21 [==============================] - 6s 289ms/step - det_loss: 0.8463 - cls_loss: 0.5471 - box_loss: 0.0060 - reg_l2_loss: 0.0636 - loss: 0.9099 - learning_rate: 0.0085 - gradient_norm: 1.9731 - val_det_loss: 0.8389 - val_cls_loss: 0.5602 - val_box_loss: 0.0056 - val_reg_l2_loss: 0.0636 - val_loss: 0.9026 Epoch 14/50 21/21 [==============================] - 6s 282ms/step - det_loss: 0.8532 - cls_loss: 0.5572 - box_loss: 0.0059 - reg_l2_loss: 0.0637 - loss: 0.9169 - learning_rate: 0.0082 - gradient_norm: 2.2388 - val_det_loss: 0.7846 - val_cls_loss: 0.5136 - val_box_loss: 0.0054 - val_reg_l2_loss: 0.0637 - val_loss: 0.8483 Epoch 15/50 21/21 [==============================] - 8s 389ms/step - det_loss: 0.8212 - cls_loss: 0.5365 - box_loss: 0.0057 - reg_l2_loss: 0.0637 - loss: 0.8849 - learning_rate: 0.0080 - gradient_norm: 2.1374 - val_det_loss: 0.7583 - val_cls_loss: 0.4988 - val_box_loss: 0.0052 - val_reg_l2_loss: 0.0637 - val_loss: 0.8219 Epoch 16/50 21/21 [==============================] - 6s 289ms/step - det_loss: 0.8107 - cls_loss: 0.5229 - box_loss: 0.0058 - reg_l2_loss: 0.0637 - loss: 0.8744 - learning_rate: 0.0077 - gradient_norm: 2.1675 - val_det_loss: 0.8215 - val_cls_loss: 0.5593 - val_box_loss: 0.0052 - val_reg_l2_loss: 0.0637 - val_loss: 0.8852 Epoch 17/50 21/21 [==============================] - 6s 286ms/step - det_loss: 0.7978 - cls_loss: 0.5132 - box_loss: 0.0057 - reg_l2_loss: 0.0637 - loss: 0.8615 - learning_rate: 0.0075 - gradient_norm: 2.3357 - val_det_loss: 0.8640 - val_cls_loss: 0.5992 - val_box_loss: 0.0053 - val_reg_l2_loss: 0.0637 - val_loss: 0.9277 Epoch 18/50 21/21 [==============================] - 6s 293ms/step - det_loss: 0.7642 - cls_loss: 0.5042 - box_loss: 0.0052 - reg_l2_loss: 0.0637 - loss: 0.8278 - learning_rate: 0.0072 - gradient_norm: 2.3080 - val_det_loss: 0.7775 - val_cls_loss: 0.5312 - val_box_loss: 0.0049 - val_reg_l2_loss: 0.0637 - val_loss: 0.8412 Epoch 19/50 21/21 [==============================] - 6s 284ms/step - det_loss: 0.7664 - cls_loss: 0.4978 - box_loss: 0.0054 - reg_l2_loss: 0.0637 - loss: 0.8301 - learning_rate: 0.0069 - gradient_norm: 2.1915 - val_det_loss: 0.7483 - val_cls_loss: 0.5196 - val_box_loss: 0.0046 - val_reg_l2_loss: 0.0637 - val_loss: 0.8120 Epoch 20/50 21/21 [==============================] - 8s 388ms/step - det_loss: 0.7418 - cls_loss: 0.4831 - box_loss: 0.0052 - reg_l2_loss: 0.0637 - loss: 0.8055 - learning_rate: 0.0066 - gradient_norm: 2.2534 - val_det_loss: 0.7673 - val_cls_loss: 0.5250 - val_box_loss: 0.0048 - val_reg_l2_loss: 0.0637 - val_loss: 0.8310 Epoch 21/50 21/21 [==============================] - 6s 280ms/step - det_loss: 0.8007 - cls_loss: 0.5087 - box_loss: 0.0058 - reg_l2_loss: 0.0637 - loss: 0.8644 - learning_rate: 0.0063 - gradient_norm: 2.5656 - val_det_loss: 0.7869 - val_cls_loss: 0.5317 - val_box_loss: 0.0051 - val_reg_l2_loss: 0.0637 - val_loss: 0.8506 Epoch 22/50 21/21 [==============================] - 6s 295ms/step - det_loss: 0.7374 - cls_loss: 0.4861 - box_loss: 0.0050 - reg_l2_loss: 0.0637 - loss: 0.8011 - learning_rate: 0.0060 - gradient_norm: 2.2847 - val_det_loss: 0.8170 - val_cls_loss: 0.5612 - val_box_loss: 0.0051 - val_reg_l2_loss: 0.0637 - val_loss: 0.8807 Epoch 
23/50 21/21 [==============================] - 6s 283ms/step - det_loss: 0.7384 - cls_loss: 0.4833 - box_loss: 0.0051 - reg_l2_loss: 0.0637 - loss: 0.8021 - learning_rate: 0.0056 - gradient_norm: 2.2279 - val_det_loss: 0.8068 - val_cls_loss: 0.5482 - val_box_loss: 0.0052 - val_reg_l2_loss: 0.0637 - val_loss: 0.8705 Epoch 24/50 21/21 [==============================] - 7s 322ms/step - det_loss: 0.7370 - cls_loss: 0.4840 - box_loss: 0.0051 - reg_l2_loss: 0.0637 - loss: 0.8007 - learning_rate: 0.0053 - gradient_norm: 2.3582 - val_det_loss: 0.8889 - val_cls_loss: 0.6238 - val_box_loss: 0.0053 - val_reg_l2_loss: 0.0637 - val_loss: 0.9526 Epoch 25/50 21/21 [==============================] - 7s 364ms/step - det_loss: 0.7079 - cls_loss: 0.4719 - box_loss: 0.0047 - reg_l2_loss: 0.0637 - loss: 0.7716 - learning_rate: 0.0050 - gradient_norm: 2.6073 - val_det_loss: 0.8224 - val_cls_loss: 0.5445 - val_box_loss: 0.0056 - val_reg_l2_loss: 0.0637 - val_loss: 0.8861 Epoch 26/50 21/21 [==============================] - 6s 288ms/step - det_loss: 0.7276 - cls_loss: 0.4669 - box_loss: 0.0052 - reg_l2_loss: 0.0637 - loss: 0.7913 - learning_rate: 0.0047 - gradient_norm: 2.4797 - val_det_loss: 0.8051 - val_cls_loss: 0.5411 - val_box_loss: 0.0053 - val_reg_l2_loss: 0.0637 - val_loss: 0.8688 Epoch 27/50 21/21 [==============================] - 6s 290ms/step - det_loss: 0.7037 - cls_loss: 0.4558 - box_loss: 0.0050 - reg_l2_loss: 0.0637 - loss: 0.7674 - learning_rate: 0.0044 - gradient_norm: 2.3431 - val_det_loss: 0.8115 - val_cls_loss: 0.5432 - val_box_loss: 0.0054 - val_reg_l2_loss: 0.0637 - val_loss: 0.8752 Epoch 28/50 21/21 [==============================] - 6s 308ms/step - det_loss: 0.7116 - cls_loss: 0.4590 - box_loss: 0.0051 - reg_l2_loss: 0.0637 - loss: 0.7753 - learning_rate: 0.0040 - gradient_norm: 2.7340 - val_det_loss: 0.8266 - val_cls_loss: 0.5705 - val_box_loss: 0.0051 - val_reg_l2_loss: 0.0637 - val_loss: 0.8904 Epoch 29/50 21/21 [==============================] - 6s 278ms/step - det_loss: 0.6887 - cls_loss: 0.4471 - box_loss: 0.0048 - reg_l2_loss: 0.0637 - loss: 0.7524 - learning_rate: 0.0037 - gradient_norm: 2.4508 - val_det_loss: 0.7780 - val_cls_loss: 0.5375 - val_box_loss: 0.0048 - val_reg_l2_loss: 0.0637 - val_loss: 0.8417 Epoch 30/50 21/21 [==============================] - 7s 359ms/step - det_loss: 0.6970 - cls_loss: 0.4540 - box_loss: 0.0049 - reg_l2_loss: 0.0637 - loss: 0.7607 - learning_rate: 0.0034 - gradient_norm: 2.4654 - val_det_loss: 0.7527 - val_cls_loss: 0.5348 - val_box_loss: 0.0044 - val_reg_l2_loss: 0.0637 - val_loss: 0.8164 Epoch 31/50 21/21 [==============================] - 6s 285ms/step - det_loss: 0.6573 - cls_loss: 0.4318 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7210 - learning_rate: 0.0031 - gradient_norm: 2.3340 - val_det_loss: 0.7645 - val_cls_loss: 0.5354 - val_box_loss: 0.0046 - val_reg_l2_loss: 0.0637 - val_loss: 0.8282 Epoch 32/50 21/21 [==============================] - 6s 279ms/step - det_loss: 0.6636 - cls_loss: 0.4338 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7273 - learning_rate: 0.0028 - gradient_norm: 2.3318 - val_det_loss: 0.7956 - val_cls_loss: 0.5594 - val_box_loss: 0.0047 - val_reg_l2_loss: 0.0637 - val_loss: 0.8593 Epoch 33/50 21/21 [==============================] - 6s 316ms/step - det_loss: 0.6770 - cls_loss: 0.4470 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7407 - learning_rate: 0.0025 - gradient_norm: 2.5163 - val_det_loss: 0.7844 - val_cls_loss: 0.5614 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8482 Epoch 
34/50 21/21 [==============================] - 6s 287ms/step - det_loss: 0.6811 - cls_loss: 0.4467 - box_loss: 0.0047 - reg_l2_loss: 0.0637 - loss: 0.7449 - learning_rate: 0.0023 - gradient_norm: 2.5778 - val_det_loss: 0.7602 - val_cls_loss: 0.5361 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8240 Epoch 35/50 21/21 [==============================] - 8s 370ms/step - det_loss: 0.6745 - cls_loss: 0.4449 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7382 - learning_rate: 0.0020 - gradient_norm: 2.5151 - val_det_loss: 0.7660 - val_cls_loss: 0.5389 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8297 Epoch 36/50 21/21 [==============================] - 6s 296ms/step - det_loss: 0.6652 - cls_loss: 0.4305 - box_loss: 0.0047 - reg_l2_loss: 0.0637 - loss: 0.7289 - learning_rate: 0.0018 - gradient_norm: 2.3834 - val_det_loss: 0.7527 - val_cls_loss: 0.5288 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8164 Epoch 37/50 21/21 [==============================] - 7s 319ms/step - det_loss: 0.6420 - cls_loss: 0.4219 - box_loss: 0.0044 - reg_l2_loss: 0.0637 - loss: 0.7058 - learning_rate: 0.0015 - gradient_norm: 2.2842 - val_det_loss: 0.7558 - val_cls_loss: 0.5281 - val_box_loss: 0.0046 - val_reg_l2_loss: 0.0637 - val_loss: 0.8195 Epoch 38/50 21/21 [==============================] - 6s 288ms/step - det_loss: 0.6515 - cls_loss: 0.4270 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7153 - learning_rate: 0.0013 - gradient_norm: 2.4729 - val_det_loss: 0.7496 - val_cls_loss: 0.5206 - val_box_loss: 0.0046 - val_reg_l2_loss: 0.0637 - val_loss: 0.8134 Epoch 39/50 21/21 [==============================] - 6s 288ms/step - det_loss: 0.6499 - cls_loss: 0.4260 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7136 - learning_rate: 0.0011 - gradient_norm: 2.3206 - val_det_loss: 0.7411 - val_cls_loss: 0.5182 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8049 Epoch 40/50 21/21 [==============================] - 8s 378ms/step - det_loss: 0.6462 - cls_loss: 0.4175 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7099 - learning_rate: 9.0029e-04 - gradient_norm: 2.4006 - val_det_loss: 0.7520 - val_cls_loss: 0.5253 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8157 Epoch 41/50 21/21 [==============================] - 6s 280ms/step - det_loss: 0.6542 - cls_loss: 0.4267 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7179 - learning_rate: 7.2543e-04 - gradient_norm: 2.3908 - val_det_loss: 0.7583 - val_cls_loss: 0.5314 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8220 Epoch 42/50 21/21 [==============================] - 6s 311ms/step - det_loss: 0.6444 - cls_loss: 0.4262 - box_loss: 0.0044 - reg_l2_loss: 0.0637 - loss: 0.7081 - learning_rate: 5.6814e-04 - gradient_norm: 2.4907 - val_det_loss: 0.7401 - val_cls_loss: 0.5170 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8039 Epoch 43/50 21/21 [==============================] - 6s 278ms/step - det_loss: 0.6464 - cls_loss: 0.4179 - box_loss: 0.0046 - reg_l2_loss: 0.0637 - loss: 0.7101 - learning_rate: 4.2906e-04 - gradient_norm: 2.3817 - val_det_loss: 0.7376 - val_cls_loss: 0.5162 - val_box_loss: 0.0044 - val_reg_l2_loss: 0.0637 - val_loss: 0.8013 Epoch 44/50 21/21 [==============================] - 6s 280ms/step - det_loss: 0.6526 - cls_loss: 0.4254 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7163 - learning_rate: 3.0876e-04 - gradient_norm: 2.4726 - val_det_loss: 0.7359 - val_cls_loss: 0.5144 - val_box_loss: 0.0044 - val_reg_l2_loss: 0.0637 - 
val_loss: 0.7996 Epoch 45/50 21/21 [==============================] - 7s 366ms/step - det_loss: 0.6283 - cls_loss: 0.4161 - box_loss: 0.0042 - reg_l2_loss: 0.0637 - loss: 0.6920 - learning_rate: 2.0774e-04 - gradient_norm: 2.3249 - val_det_loss: 0.7380 - val_cls_loss: 0.5157 - val_box_loss: 0.0044 - val_reg_l2_loss: 0.0637 - val_loss: 0.8017 Epoch 46/50 21/21 [==============================] - 6s 315ms/step - det_loss: 0.6479 - cls_loss: 0.4240 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7116 - learning_rate: 1.2641e-04 - gradient_norm: 2.5665 - val_det_loss: 0.7422 - val_cls_loss: 0.5183 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8059 Epoch 47/50 21/21 [==============================] - 6s 279ms/step - det_loss: 0.6418 - cls_loss: 0.4172 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7055 - learning_rate: 6.5107e-05 - gradient_norm: 2.6628 - val_det_loss: 0.7433 - val_cls_loss: 0.5190 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8070 Epoch 48/50 21/21 [==============================] - 6s 284ms/step - det_loss: 0.6513 - cls_loss: 0.4284 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7150 - learning_rate: 2.4083e-05 - gradient_norm: 2.4614 - val_det_loss: 0.7444 - val_cls_loss: 0.5196 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8081 Epoch 49/50 21/21 [==============================] - 6s 290ms/step - det_loss: 0.6410 - cls_loss: 0.4188 - box_loss: 0.0044 - reg_l2_loss: 0.0637 - loss: 0.7048 - learning_rate: 3.5074e-06 - gradient_norm: 2.4491 - val_det_loss: 0.7436 - val_cls_loss: 0.5192 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8074 Epoch 50/50 21/21 [==============================] - 8s 369ms/step - det_loss: 0.6415 - cls_loss: 0.4182 - box_loss: 0.0045 - reg_l2_loss: 0.0637 - loss: 0.7052 - learning_rate: 3.4629e-06 - gradient_norm: 2.3632 - val_det_loss: 0.7439 - val_cls_loss: 0.5192 - val_box_loss: 0.0045 - val_reg_l2_loss: 0.0637 - val_loss: 0.8076
Step 4. Evaluate the model with the test data.
After training the object detection model using the images in the training dataset, use the remaining 25 images in the test dataset to evaluate how the model performs against new data it has never seen before.
As the default batch size is 64, it will take 1 step to go through the 25 images in the test dataset.
The evaluation metrics are the same as COCO.
model.evaluate(test_data)
1/1 [==============================] - 6s 6s/step {'AP': 0.1919611, 'AP50': 0.34319812, 'AP75': 0.1862729, 'APs': -1.0, 'APm': 0.3241761, 'APl': 0.1923377, 'ARmax1': 0.16314393, 'ARmax10': 0.3077774, 'ARmax100': 0.35298797, 'ARs': -1.0, 'ARm': 0.55, 'ARl': 0.35037425, 'AP_/Baked Goods': 0.008910891, 'AP_/Salad': 0.5489583, 'AP_/Cheese': 0.15109558, 'AP_/Seafood': 0.05949282, 'AP_/Tomato': 0.19134793}
Step 5. Export as a TensorFlow Lite model.
Export the trained object detection model to the TensorFlow Lite format by specifying which folder you want to export the quantized model to. The default post-training quantization technique is full integer quantization.
model.export(export_dir='.')
2022-08-31 02:23:43.606436: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. 2022-08-31 02:24:08.865819: W tensorflow/core/common_runtime/graph_constructor.cc:803] Node 'resample_p7/PartitionedCall' has 1 outputs but the _output_shapes attribute specifies shapes for 3 outputs. Output shapes may be inaccurate. 2022-08-31 02:24:16.164308: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format. 2022-08-31 02:24:16.164355: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency. fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 0
Step 6. Evaluate the TensorFlow Lite model.
Several factors can affect the model accuracy when exporting to TFLite:
- Quantization helps shrink the model size to about a quarter of its original size, at the expense of a slight accuracy drop.
- The original TensorFlow model uses per-class non-max suppression (NMS) for post-processing, while the TFLite model uses global NMS, which is much faster but less accurate. Keras outputs at most 100 detections while TFLite outputs at most 25 detections.
Therefore you'll have to evaluate the exported TFLite model and compare its accuracy with the original TensorFlow model.
model.evaluate_tflite('model.tflite', test_data)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU. 25/25 [==============================] - 67s 3s/step {'AP': 0.18182021, 'AP50': 0.31197006, 'AP75': 0.19087325, 'APs': -1.0, 'APm': 0.2809451, 'APl': 0.18239675, 'ARmax1': 0.14099042, 'ARmax10': 0.25025177, 'ARmax100': 0.26311275, 'ARs': -1.0, 'ARm': 0.34166667, 'ARl': 0.26235557, 'AP_/Baked Goods': 0.0, 'AP_/Salad': 0.5375744, 'AP_/Cheese': 0.14260961, 'AP_/Seafood': 0.05841584, 'AP_/Tomato': 0.17050114}
You can download the TensorFlow Lite model file using the left sidebar of Colab. Right-click on the model.tflite file and choose Download to download it to your local computer.
This model can be integrated into an Android or an iOS app using the ObjectDetector API of the TensorFlow Lite Task Library.
See the TFLite Object Detection sample app for more details on how the model is used in a working app.
Note: Android Studio Model Binding does not support object detection yet, so please use the TensorFlow Lite Task Library.
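The Task Library also provides Python bindings, which can be handy for prototyping the same integration on a desktop before moving to mobile. A minimal sketch, assuming the tflite-support package is installed and its task API matches recent releases:
# A minimal sketch, assuming `pip install tflite-support` and an API layout
# matching recent tflite-support releases. The image path is a placeholder.
from tflite_support.task import core, processor, vision
base_options = core.BaseOptions(file_name='model.tflite')
detection_options = processor.DetectionOptions(score_threshold=0.3, max_results=5)
options = vision.ObjectDetectorOptions(base_options=base_options,
                                       detection_options=detection_options)
detector = vision.ObjectDetector.create_from_options(options)
image = vision.TensorImage.create_from_file('salad.jpg')
print(detector.detect(image))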
(Optional) Test the TFLite model on your image
You can test the trained TFLite model using images from the internet.
- Replace the INPUT_IMAGE_URL below with your desired input image.
- Adjust the DETECTION_THRESHOLD to change the sensitivity of the model. A lower threshold means the model will pick up more objects, but there will also be more false detections. Meanwhile, a higher threshold means the model will only pick up objects that it has confidently detected.
Although it requires some boilerplate code to run the model in Python at this moment, integrating the model into a mobile app only requires a few lines of code.
Load the trained TFLite model and define some visualization functions
import cv2

from PIL import Image

model_path = 'model.tflite'

# Load the labels into a list
classes = ['???'] * model.model_spec.config.num_classes
label_map = model.model_spec.config.label_map
for label_id, label_name in label_map.as_dict().items():
  classes[label_id-1] = label_name

# Define a list of colors for visualization
COLORS = np.random.randint(0, 255, size=(len(classes), 3), dtype=np.uint8)

def preprocess_image(image_path, input_size):
  """Preprocess the input image to feed to the TFLite model"""
  img = tf.io.read_file(image_path)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.uint8)
  original_image = img
  resized_img = tf.image.resize(img, input_size)
  resized_img = resized_img[tf.newaxis, :]
  resized_img = tf.cast(resized_img, dtype=tf.uint8)
  return resized_img, original_image

def detect_objects(interpreter, image, threshold):
  """Returns a list of detection results, each a dictionary of object info."""
  signature_fn = interpreter.get_signature_runner()

  # Feed the input image to the model
  output = signature_fn(images=image)

  # Get all outputs from the model
  count = int(np.squeeze(output['output_0']))
  scores = np.squeeze(output['output_1'])
  classes = np.squeeze(output['output_2'])
  boxes = np.squeeze(output['output_3'])

  results = []
  for i in range(count):
    if scores[i] >= threshold:
      result = {
        'bounding_box': boxes[i],
        'class_id': classes[i],
        'score': scores[i]
      }
      results.append(result)
  return results

def run_odt_and_draw_results(image_path, interpreter, threshold=0.5):
  """Run object detection on the input image and draw the detection results"""
  # Load the input shape required by the model
  _, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']

  # Load the input image and preprocess it
  preprocessed_image, original_image = preprocess_image(
      image_path,
      (input_height, input_width)
  )

  # Run object detection on the input image
  results = detect_objects(interpreter, preprocessed_image, threshold=threshold)

  # Plot the detection results on the input image
  original_image_np = original_image.numpy().astype(np.uint8)
  for obj in results:
    # Convert the object bounding box from relative coordinates to absolute
    # coordinates based on the original image resolution
    ymin, xmin, ymax, xmax = obj['bounding_box']
    xmin = int(xmin * original_image_np.shape[1])
    xmax = int(xmax * original_image_np.shape[1])
    ymin = int(ymin * original_image_np.shape[0])
    ymax = int(ymax * original_image_np.shape[0])

    # Find the class index of the current object
    class_id = int(obj['class_id'])

    # Draw the bounding box and label on the image
    color = [int(c) for c in COLORS[class_id]]
    cv2.rectangle(original_image_np, (xmin, ymin), (xmax, ymax), color, 2)
    # Make adjustments to make the label visible for all objects
    y = ymin - 15 if ymin - 15 > 15 else ymin + 15
    label = "{}: {:.0f}%".format(classes[class_id], obj['score'] * 100)
    cv2.putText(original_image_np, label, (xmin, y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

  # Return the final image
  original_uint8 = original_image_np.astype(np.uint8)
  return original_uint8
Run object detection and show the detection results
INPUT_IMAGE_URL = "https://storage.googleapis.com/cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg"
DETECTION_THRESHOLD = 0.3
TEMP_FILE = '/tmp/image.png'
!wget -q -O $TEMP_FILE $INPUT_IMAGE_URL
im = Image.open(TEMP_FILE)
im.thumbnail((512, 512), Image.ANTIALIAS)
im.save(TEMP_FILE, 'PNG')
# Load the TFLite model
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
# Run inference and draw detection result on the local copy of the original file
detection_result_image = run_odt_and_draw_results(
TEMP_FILE,
interpreter,
threshold=DETECTION_THRESHOLD
)
# Show the detection result
Image.fromarray(detection_result_image)
/tmpfs/tmp/ipykernel_81842/1212634245.py:10: DeprecationWarning: ANTIALIAS is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.LANCZOS instead. im.thumbnail((512, 512), Image.ANTIALIAS)
(Optional) Compile for the Edge TPU
Now that you have a quantized EfficientDet-Lite model, it is possible to compile it and deploy it to a Coral Edge TPU.
Step 1. Install the EdgeTPU Compiler
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list
sudo apt-get update
sudo apt-get install edgetpu-compiler
% Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 2537 100 2537 0 0 130k 0 --:--:-- --:--:-- --:--:-- 130k OK deb https://packages.cloud.google.com/apt coral-edgetpu-stable main Hit:1 http://us-central1.gce.archive.ubuntu.com/ubuntu focal InRelease Hit:2 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates InRelease Get:3 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-backports InRelease [108 kB] Hit:4 http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 InRelease Get:6 https://packages.cloud.google.com/apt coral-edgetpu-stable InRelease [6722 B] Hit:7 https://nvidia.github.io/libnvidia-container/stable/ubuntu18.04/amd64 InRelease Hit:8 https://download.docker.com/linux/ubuntu focal InRelease Get:9 https://nvidia.github.io/nvidia-container-runtime/stable/ubuntu18.04/amd64 InRelease [1481 B] Get:10 https://nvidia.github.io/nvidia-docker/ubuntu18.04/amd64 InRelease [1474 B] Hit:11 http://security.ubuntu.com/ubuntu focal-security InRelease Ign:12 http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 InRelease Hit:13 http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 Release Hit:5 https://apt.llvm.org/focal llvm-toolchain-focal-14 InRelease Hit:14 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease Ign:15 https://packages.cloud.google.com/apt coral-edgetpu-stable/main amd64 Packages Hit:16 http://ppa.launchpad.net/longsleep/golang-backports/ubuntu focal InRelease Hit:17 http://ppa.launchpad.net/openjdk-r/ppa/ubuntu focal InRelease Get:15 https://packages.cloud.google.com/apt coral-edgetpu-stable/main amd64 Packages [2317 B] Fetched 120 kB in 1s (82.9 kB/s) The following packages were automatically installed and are no longer required: libatasmart4 libblockdev-fs2 libblockdev-loop2 libblockdev-part-err2 libblockdev-part2 libblockdev-swap2 libblockdev-utils2 libblockdev2 libnuma1 libparted-fs-resize0 Use 'sudo apt autoremove' to remove them. The following NEW packages will be installed: edgetpu-compiler 0 upgraded, 1 newly installed, 0 to remove and 258 not upgraded. Need to get 7913 kB of archives. After this operation, 31.2 MB of additional disk space will be used. Get:1 https://packages.cloud.google.com/apt coral-edgetpu-stable/main amd64 edgetpu-compiler amd64 16.0 [7913 kB] Fetched 7913 kB in 0s (30.0 MB/s) Selecting previously unselected package edgetpu-compiler. (Reading database ... 139823 files and directories currently installed.) Preparing to unpack .../edgetpu-compiler_16.0_amd64.deb ... Unpacking edgetpu-compiler (16.0) ... Setting up edgetpu-compiler (16.0) ... Processing triggers for libc-bin (2.31-0ubuntu9.7) ...
Step 2. Select the number of Edge TPUs, then compile
The Edge TPU has 8MB of SRAM for caching model parameters (more info). This means that for models larger than 8MB, inference time will increase in order to transfer over model parameters. One way to avoid this is model pipelining - splitting the model into segments that can each have a dedicated Edge TPU. This can significantly improve latency.
The table below can be used as a reference for the number of Edge TPUs to use - the larger models will not compile for a single TPU, as the intermediate tensors can't fit in on-chip memory.
Model architecture | Minimum TPUs | Recommended TPUs |
---|---|---|
EfficientDet-Lite0 | 1 | 1 |
EfficientDet-Lite1 | 1 | 1 |
EfficientDet-Lite2 | 1 | 2 |
EfficientDet-Lite3 | 2 | 2 |
EfficientDet-Lite4 | 2 | 3 |
NUMBER_OF_TPUS = 1
!edgetpu_compiler model.tflite --num_segments=$NUMBER_OF_TPUS
Edge TPU Compiler version 16.0.384591198 Started a compilation timeout timer of 180 seconds. Model compiled successfully in 4407 ms. Input model: model.tflite Input size: 4.24MiB Output model: model_edgetpu.tflite Output size: 5.61MiB On-chip memory used for caching model parameters: 4.24MiB On-chip memory remaining for caching model parameters: 3.27MiB Off-chip memory used for streaming uncached model parameters: 0.00B Number of Edge TPU subgraphs: 1 Total number of operations: 267 Operation log: model_edgetpu.log Model successfully compiled but not all operations are supported by the Edge TPU. A percentage of the model will instead run on the CPU, which is slower. If possible, consider updating your model to use only operations supported by the Edge TPU. For details, visit g.co/coral/model-reqs. Number of operations that will run on Edge TPU: 264 Number of operations that will run on CPU: 3 See the operation log file for individual operation details. Compilation child process completed within timeout period. Compilation succeeded!
Step 3. Download and run the model
With the model(s) compiled, they can now be run on Edge TPU(s) for object detection. First, download the compiled TensorFlow Lite model file using the left sidebar of Colab. Right-click on the model_edgetpu.tflite file and choose Download to download it to your local computer.
Now you can run the model in your preferred manner. Examples of detection include:
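For example, here is a minimal sketch of running the compiled model on a Linux host with the Edge TPU runtime installed (the delegate library name is platform-specific; libedgetpu.so.1 is the Linux convention):
# A minimal sketch, assuming the Edge TPU runtime (libedgetpu) is installed.
delegate = tf.lite.experimental.load_delegate('libedgetpu.so.1')
interpreter = tf.lite.Interpreter(model_path='model_edgetpu.tflite',
                                  experimental_delegates=[delegate])
interpreter.allocate_tensors()
# The run_odt_and_draw_results helper defined earlier can then be reused:
# run_odt_and_draw_results('/tmp/image.png', interpreter, threshold=0.3)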
Advanced Usage
This section covers advanced usage topics, such as adjusting the model and the training hyperparameters.
Load the dataset
Load your own data
You can upload your own dataset to work through this tutorial. Upload your dataset by using the left sidebar in Colab.
If you prefer not to upload your dataset to the cloud, you can also run the library locally by following the guide.
Load your data with a different data format
The Model Maker library also supports the object_detector.DataLoader.from_pascal_voc method to load data in the PASCAL VOC format. makesense.ai and LabelImg are tools that can annotate images and save the annotations as XML files in the PASCAL VOC data format:
object_detector.DataLoader.from_pascal_voc(image_dir, annotations_dir, label_map={1: "person", 2: "notperson"})
Customize the EfficientDet model hyperparameters
The model and training pipeline parameters you can adjust are:
- model_dir: The location to save the model checkpoint files. If not set, a temporary directory will be used.
- steps_per_execution: Number of steps per training execution.
- moving_average_decay: Float. The decay to use for maintaining moving averages of the trained parameters.
- var_freeze_expr: The regular expression to map the prefix name of variables to be frozen, which means they remain the same during training. More specifically, re.match(var_freeze_expr, variable_name) is used in the codebase to map the variables to be frozen.
- tflite_max_detections: Integer, 25 by default. The max number of output detections in the TFLite model.
- strategy: A string specifying which distribution strategy to use. Accepted values are 'tpu', 'gpus', None. 'tpu' means to use TPUStrategy. 'gpus' means to use MirroredStrategy for multi-gpus. If None, use the TF default with OneDeviceStrategy.
- tpu: The Cloud TPU to use for training. This should be either the name used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.
- use_xla: Use XLA even if the strategy is not TPU. If the strategy is TPU, always use XLA, and this flag has no effect.
- profile: Enable profile mode.
- debug: Enable debug mode.
Other parameters that can be adjusted are shown in hparams_config.py.
For example, you can set var_freeze_expr='efficientnet' to freeze the variables with the name prefix efficientnet (the default is '(efficientnet|fpn_cells|resample_p6)'). This allows the model to freeze untrainable variables and keep their values unchanged through training.
spec = model_spec.get('efficientdet_lite0')
spec.config.var_freeze_expr = 'efficientnet'
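The same pattern applies to the other knobs listed above; a brief sketch, assuming these hyperparameters are exposed on spec.config just like var_freeze_expr (the values are arbitrary examples):
# A brief sketch; values are arbitrary, and the attributes are assumed to
# be exposed on spec.config like var_freeze_expr above.
spec.config.tflite_max_detections = 10  # cap the TFLite output detections
spec.config.moving_average_decay = 0    # disable parameter moving averages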
Change the model architecture
You can change the model architecture by changing the model_spec. For example, change the model_spec to the EfficientDet-Lite4 model.
spec = model_spec.get('efficientdet_lite4')
调整训练超参数
create
函数是 Model Maker 库用于创建模型的驱动函数。model_spec
参数定义模型规范。目前支持 object_detector.EfficientDetSpec
类。create
函数包括以下步骤:
- 根据
model_spec
创建用于目标检测的模型。 - 训练模型。默认周期和默认批次大小由
model_spec
对象中的epochs
和batch_size
变量设置。您还可以调整训练超参数,如影响模型准确率的epochs
和batch_size
。例如,
epochs
:整数,默认为 50。更多周期可以获得更好的准确率,但可以会导致过拟合。batch_size
:整数,默认为 64。一个训练步骤中要使用的样本数。train_whole_model
:布尔值,默认为 False。如果为 true,则训练整个模型。否则,只训练不匹配var_freeze_expr
的层。
例如,您可以使用较少的周期进行训练,并且只使用头层。您可以增加周期数以获得更好的效果。
model = object_detector.create(train_data, model_spec=spec, epochs=10, validation_data=validation_data)
Export to different formats
The export formats can be one or a list of the following:
- ExportFormat.TFLITE
- ExportFormat.SAVED_MODEL
- ExportFormat.LABEL
By default, it exports only the TensorFlow Lite model file containing the model metadata, so that you can later use it in an on-device ML application. The label file is embedded in the metadata.
In many on-device ML applications, the model size is an important factor. Therefore, it is recommended that you quantize the model to make it smaller and potentially run faster. As for EfficientDet-Lite models, full integer quantization is used to quantize the model by default. Please refer to Post-training quantization for more detail.
model.export(export_dir='.')
You can also choose to export other files related to the model for better examination. For instance, export both the saved model and the label file as follows:
model.export(export_dir='.', export_format=[ExportFormat.SAVED_MODEL, ExportFormat.LABEL])
Customize post-training quantization on the TensorFlow Lite model
Post-training quantization is a conversion technique that can reduce model size and inference latency, while also improving CPU and hardware accelerator inference speed, with little degradation in model accuracy. Thus, it's widely used to optimize the model.
The Model Maker library applies a default post-training quantization technique when exporting the model. If you want to customize post-training quantization, Model Maker also supports multiple post-training quantization options using QuantizationConfig. Let's take float16 quantization as an instance. First, define the quantization config.
config = QuantizationConfig.for_float16()
Then we export the TensorFlow Lite model with this configuration.
model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)
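Other presets follow the same pattern; for instance, here is a sketch using dynamic-range quantization, assuming the for_dynamic preset provided by QuantizationConfig:
# A sketch using the dynamic-range preset; the filename is arbitrary.
config = QuantizationConfig.for_dynamic()
model.export(export_dir='.', tflite_filename='model_dynamic.tflite', quantization_config=config)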
Read more
You can read our object detection example to learn the technical details. For more information, please refer to:
- TensorFlow Lite Model Maker guide and API reference.
- Task Library: ObjectDetector for deployment.
- The end-to-end reference apps: Android, iOS, and Raspberry Pi.