Este documento explica cómo entrenar un modelo y ejecutar inferencias utilizando un microcontrolador.
El ejemplo de Hola Mundo
El ejemplo de Hello World está diseñado para demostrar los conceptos básicos absolutos del uso de TensorFlow Lite para microcontroladores. Entrenamos y ejecutamos un modelo que replica una función seno, es decir, toma un solo número como entrada y genera el valor seno del número. Cuando se implementa en el microcontrolador, sus predicciones se utilizan para hacer parpadear los LED o controlar una animación.
El flujo de trabajo de un extremo a otro implica los siguientes pasos:
- Entrenar un modelo (en Python): un archivo de Python para entrenar, convertir y optimizar un modelo para su uso en el dispositivo.
- Ejecutar inferencia (en C++ 17): una prueba unitaria de un extremo a otro que ejecuta inferencia en el modelo utilizando la biblioteca C++ .
Obtenga un dispositivo compatible
La aplicación de ejemplo que usaremos ha sido probada en los siguientes dispositivos:
- Arduino Nano 33 BLE Sense (usando Arduino IDE)
- SparkFun Edge (construido directamente desde la fuente)
- Kit de descubrimiento STM32F746 (usando Mbed)
- Adafruit EdgeBadge (usando Arduino IDE)
- Kit Adafruit TensorFlow Lite para microcontroladores (usando Arduino IDE)
- Adafruit Circuit Playground Bluefruit (usando Arduino IDE)
- Espressif ESP32-DevKitC (usando ESP IDF)
- Espressif ESP-EYE (usando ESP IDF)
Obtenga más información sobre las plataformas compatibles en TensorFlow Lite para microcontroladores .
Entrenar un modelo
Utilice train.py para entrenar el modelo hola mundo para el reconocimiento de ondas sinusoidales
Ejecutar: bazel build tensorflow/lite/micro/examples/hello_world:train
bazel-bin/tensorflow/lite/micro/examples/hello_world/train --save_tf_model --save_dir=/tmp/model_created/
Ejecutar inferencia
Para ejecutar el modelo en su dispositivo, seguiremos las instrucciones en README.md
:
Las siguientes secciones recorren la prueba unitaria evaluate_test.cc
del ejemplo que demuestra cómo ejecutar la inferencia usando TensorFlow Lite para microcontroladores. Carga el modelo y ejecuta la inferencia varias veces.
1. Incluya los encabezados de la biblioteca.
Para utilizar la biblioteca TensorFlow Lite para microcontroladores, debemos incluir los siguientes archivos de encabezado:
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
-
micro_mutable_op_resolver.h
proporciona las operaciones utilizadas por el intérprete para ejecutar el modelo. -
micro_error_reporter.h
genera información de depuración. -
micro_interpreter.h
contiene código para cargar y ejecutar modelos. -
schema_generated.h
contiene el esquema para el formato de archivo del modeloFlatBuffer
TensorFlow Lite. -
version.h
proporciona información de versiones para el esquema de TensorFlow Lite.
2. Incluya el encabezado del modelo.
El intérprete de TensorFlow Lite para microcontroladores espera que el modelo se proporcione como una matriz C++. El modelo se define en los archivos model.h
y model.cc
. El encabezado se incluye con la siguiente línea:
#include "tensorflow/lite/micro/examples/hello_world/model.h"
3. Incluya el encabezado del marco de prueba unitaria.
Para crear una prueba unitaria, incluimos el marco de prueba unitaria de TensorFlow Lite para microcontroladores incluyendo la siguiente línea:
#include "tensorflow/lite/micro/testing/micro_test.h"
La prueba se define utilizando las siguientes macros:
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
. // add code here
.
}
TF_LITE_MICRO_TESTS_END
Ahora analizamos el código incluido en la macro anterior.
4. Configurar el registro
Para configurar el registro, se crea un puntero tflite::ErrorReporter
usando un puntero a una instancia tflite::MicroErrorReporter
:
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = µ_error_reporter;
Esta variable se pasará al intérprete, lo que le permite escribir registros. Dado que los microcontroladores suelen tener una variedad de mecanismos para iniciar sesión, la implementación de tflite::MicroErrorReporter
está diseñada para personalizarse para su dispositivo en particular.
5. Cargar un modelo
En el siguiente código, se crea una instancia del modelo utilizando datos de una matriz char
, g_model
, que se declara en model.h
. Luego verificamos el modelo para asegurarnos de que su versión de esquema sea compatible con la versión que estamos usando:
const tflite::Model* model = ::tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.\n",
model->version(), TFLITE_SCHEMA_VERSION);
}
6. Crear instancias de resolución de operaciones
Se declara una instancia MicroMutableOpResolver
. Esto será utilizado por el intérprete para registrar y acceder a las operaciones que utiliza el modelo:
using HelloWorldOpResolver = tflite::MicroMutableOpResolver<1>;
TfLiteStatus RegisterOps(HelloWorldOpResolver& op_resolver) {
TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());
return kTfLiteOk;
MicroMutableOpResolver
requiere un parámetro de plantilla que indique la cantidad de operaciones que se registrarán. La función RegisterOps
registra las operaciones con el solucionador.
HelloWorldOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(RegisterOps(op_resolver));
7. Asignar memoria
Necesitamos preasignar una cierta cantidad de memoria para matrices de entrada, salida y intermedias. Esto se proporciona como una matriz uint8_t
de tamaño tensor_arena_size
:
const int tensor_arena_size = 2 * 1024;
uint8_t tensor_arena[tensor_arena_size];
El tamaño requerido dependerá del modelo que esté utilizando y es posible que deba determinarse mediante experimentación.
8. Intérprete instanciado
Creamos una instancia tflite::MicroInterpreter
, pasando las variables creadas anteriormente:
tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
tensor_arena_size, error_reporter);
9. Asignar tensores
Le decimos al intérprete que asigne memoria del tensor_arena
para los tensores del modelo:
interpreter.AllocateTensors();
10. Validar la forma de entrada
La instancia MicroInterpreter
puede proporcionarnos un puntero al tensor de entrada del modelo llamando .input(0)
, donde 0
representa el primer (y único) tensor de entrada:
// Obtain a pointer to the model's input tensor
TfLiteTensor* input = interpreter.input(0);
Luego inspeccionamos este tensor para confirmar que su forma y tipo son los que esperamos:
// Make sure the input has the properties we expect
TF_LITE_MICRO_EXPECT_NE(nullptr, input);
// The property "dims" tells us the tensor's shape. It has one element for
// each dimension. Our input is a 2D tensor containing 1 element, so "dims"
// should have size 2.
TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
// The value of each element gives the length of the corresponding tensor.
// We should expect two single element tensors (one is contained within the
// other).
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
// The input is a 32 bit floating point value
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, input->type);
El valor de enumeración kTfLiteFloat32
es una referencia a uno de los tipos de datos de TensorFlow Lite y está definido en common.h
.
11. Proporcione un valor de entrada
Para proporcionar una entrada al modelo, configuramos el contenido del tensor de entrada de la siguiente manera:
input->data.f[0] = 0.;
En este caso, ingresamos un valor de punto flotante que representa 0
.
12. Ejecute el modelo
Para ejecutar el modelo, podemos llamar Invoke()
en nuestra instancia tflite::MicroInterpreter
:
TfLiteStatus invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed\n");
}
Podemos verificar el valor de retorno, TfLiteStatus
, para determinar si la ejecución fue exitosa. Los valores posibles de TfLiteStatus
, definidos en common.h
, son kTfLiteOk
y kTfLiteError
.
El siguiente código afirma que el valor es kTfLiteOk
, lo que significa que la inferencia se ejecutó correctamente.
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
13. Obtener el resultado
El tensor de salida del modelo se puede obtener llamando output(0)
en tflite::MicroInterpreter
, donde 0
representa el primer (y único) tensor de salida.
En el ejemplo, la salida del modelo es un único valor de coma flotante contenido dentro de un tensor 2D:
TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, output->type);
Podemos leer el valor directamente del tensor de salida y afirmar que es lo que esperamos:
// Obtain the output value from the tensor
float value = output->data.f[0];
// Check that the output value is within 0.05 of the expected value
TF_LITE_MICRO_EXPECT_NEAR(0., value, 0.05);
14. Ejecute la inferencia nuevamente
El resto del código ejecuta la inferencia varias veces más. En cada caso, asignamos un valor al tensor de entrada, invocamos al intérprete y leemos el resultado del tensor de salida:
input->data.f[0] = 1.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.841, value, 0.05);
input->data.f[0] = 3.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.141, value, 0.05);
input->data.f[0] = 5.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(-0.959, value, 0.05);