Post-training integer quantization

Overview

Integer quantization is an optimization strategy that converts 32-bit floating-point numbers (such as weights and activation outputs) to 8-bit fixed-point numbers. This results in a smaller model and faster inference, which is valuable for low-power devices such as microcontrollers. This data format is also required by integer-only accelerators such as the Edge TPU.
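
Conceptually, the conversion is an affine mapping between real values and 8-bit integers, real ≈ (q - zero_point) * scale. Here is a minimal sketch of that arithmetic (the scale and zero-point values below are illustrative choices, not what TensorFlow Lite would compute):

import numpy as np

def quantize(reals, scale, zero_point):
  # Map float values onto the uint8 grid: q = round(r / scale) + zero_point.
  q = np.round(reals / scale) + zero_point
  return np.clip(q, 0, 255).astype(np.uint8)

def dequantize(q, scale, zero_point):
  # Recover approximate float values: r ≈ (q - zero_point) * scale.
  return (q.astype(np.float32) - zero_point) * scale

weights = np.array([-0.8, 0.0, 0.5, 1.2], dtype=np.float32)
scale, zero_point = 2.0 / 255.0, 102  # illustrative values covering [-0.8, 1.2]
print(dequantize(quantize(weights, scale, zero_point), scale, zero_point))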

In this tutorial, you'll train an MNIST model from scratch, convert it into a TensorFlow Lite file, and quantize it using post-training quantization. Finally, you'll check the accuracy of the converted model and compare it to the original float model.

You actually have several options as to how much you want to quantize a model. In this tutorial, you'll perform "full integer quantization," which converts all weights and activation outputs into 8-bit integer data, whereas other strategies may leave some amount of data in floating point.

To learn more about the various quantization strategies, read about TensorFlow Lite model optimization.

Setup

In order to quantize both the input and output tensors, we need to use APIs that were added in TensorFlow 2.3:

import logging
logging.getLogger("tensorflow").setLevel(logging.DEBUG)

import tensorflow as tf
import numpy as np
print("TensorFlow version: ", tf.__version__)
2022-08-11 19:07:46.234983: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-11 19:07:47.061607: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-11 19:07:47.061943: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-11 19:07:47.061957: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
TensorFlow version:  2.10.0-rc0

Generate a TensorFlow model

We'll build a simple model to classify digits from the MNIST dataset.

This training won't take long because you're training the model for just 5 epochs, which reaches about 98% accuracy.

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0

# Define the model architecture
model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(
                  from_logits=True),
              metrics=['accuracy'])
model.fit(
  train_images,
  train_labels,
  epochs=5,
  validation_data=(test_images, test_labels)
)
Epoch 1/5
1875/1875 [==============================] - 7s 3ms/step - loss: 0.2688 - accuracy: 0.9261 - val_loss: 0.1208 - val_accuracy: 0.9632
Epoch 2/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.1062 - accuracy: 0.9701 - val_loss: 0.0870 - val_accuracy: 0.9746
Epoch 3/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.0805 - accuracy: 0.9766 - val_loss: 0.0746 - val_accuracy: 0.9770
Epoch 4/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.0668 - accuracy: 0.9801 - val_loss: 0.0683 - val_accuracy: 0.9783
Epoch 5/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.0581 - accuracy: 0.9832 - val_loss: 0.0679 - val_accuracy: 0.9775
<keras.callbacks.History at 0x7fc9011f51c0>

Convert to a TensorFlow Lite model

Now you can convert the trained model to TensorFlow Lite format using the TensorFlow Lite Converter, and apply varying degrees of quantization.

Beware that some versions of quantization leave some of the data in float format. So the following sections show each option with increasing amounts of quantization, until we get a model that's entirely int8 or uint8 data. (Notice that we duplicate some code in each section so you can see all the quantization steps for each option.)

First, here's a converted model with no quantization:

converter = tf.lite.TFLiteConverter.from_keras_model(model)

tflite_model = converter.convert()
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkjury1v_/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkjury1v_/assets
2022-08-11 19:08:22.226815: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-08-11 19:08:22.226853: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.

It's now a TensorFlow Lite model, but it's still using 32-bit float values for all parameter data.
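
As a quick sanity check (not part of the original notebook flow), you can load the model into the interpreter and list the tensor dtypes it contains:

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
# Collect the set of dtypes across all tensors in the graph.
print({t['dtype'].__name__ for t in interpreter.get_tensor_details()})
# Expect float32 for weights and activations (an int32 shape tensor may also appear).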

Convert using dynamic range quantization

Now let's enable the default optimizations flag to quantize all fixed parameters (such as weights):

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model_quant = converter.convert()
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpjt4keerb/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpjt4keerb/assets
2022-08-11 19:08:23.373341: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-08-11 19:08:23.373380: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.

The model is now a bit smaller with quantized weights, but other variable data is still in float format.
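
As a quick check of that size difference (both converted models are still in memory as Python bytes objects):

print("Float model size:     %d bytes" % len(tflite_model))
print("Quantized model size: %d bytes" % len(tflite_model_quant))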

Convert using float fallback quantization

To quantize the variable data (such as model input/output and intermediates between layers), you need to provide a RepresentativeDataset. This is a generator function that provides a set of input data large enough to represent typical values, which allows the converter to estimate a dynamic range for all the variable data. (The dataset does not need to be unique compared to the training or evaluation dataset.) To support multiple inputs, each representative data point is a list, and elements in the list are fed to the model according to their indices (a hypothetical multi-input sketch follows the conversion output below).

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    # Model has only one input so each data point has one element.
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen

tflite_model_quant = converter.convert()
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpqs29u3y8/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpqs29u3y8/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:766: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn("Statistics for quantized inputs were expected, but not "
2022-08-11 19:08:24.526139: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-08-11 19:08:24.526176: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32
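
The generator above handles the single-input case. For a model with multiple inputs, each yielded data point is a list with one array per input, matched by index. A hypothetical sketch (the placeholder arrays stand in for real calibration data and do not belong to the MNIST model above):

# Placeholder calibration data for a hypothetical two-input model.
image_samples = np.random.rand(100, 28, 28).astype(np.float32)
aux_samples = np.random.rand(100, 4).astype(np.float32)

def representative_data_gen_multi():
  for image, aux in zip(image_samples, aux_samples):
    # One list entry per model input, ordered to match the model's inputs.
    yield [np.expand_dims(image, 0), np.expand_dims(aux, 0)]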

Now all weights and variable data are quantized, and the model is significantly smaller compared to the original TensorFlow Lite model.

However, to maintain compatibility with applications that traditionally use float model input and output tensors, the TensorFlow Lite Converter leaves the model's input and output tensors in float:

interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
input:  <class 'numpy.float32'>
output:  <class 'numpy.float32'>

That's usually good for compatibility, but it won't be compatible with devices that perform only integer-based operations, such as the Edge TPU.

Additionally, the above process may leave an operation in float format if TensorFlow Lite doesn't include a quantized implementation for that operation. This strategy still allows conversion to complete, so you get a smaller and more efficient model, but again, it won't be compatible with integer-only hardware. (All ops in this MNIST model do have a quantized implementation.)

So to ensure an end-to-end integer-only model, you need a couple more parameters…

Convert using integer-only quantization

To quantize the input and output tensors, and make the converter throw an error if it encounters an operation it cannot quantize, convert the model again with some additional parameters:

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

tflite_model_quant = converter.convert()
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphnqo4e52/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphnqo4e52/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:766: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn("Statistics for quantized inputs were expected, but not "
2022-08-11 19:08:26.493820: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-08-11 19:08:26.493859: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
fully_quantize: 0, inference_type: 6, input_inference_type: UINT8, output_inference_type: UINT8

The internal quantization remains the same as above, but you can see the input and output tensors are now integer format:

interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
input:  <class 'numpy.uint8'>
output:  <class 'numpy.uint8'>

Now you have an integer quantized model that uses integer data for the model's input and output tensors, so it's compatible with integer-only hardware such as the Edge TPU.
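
In an application you would quantize inputs and dequantize outputs yourself, using each tensor's quantization parameters. A minimal sketch for the output side (the helper function below does the equivalent rescaling for inputs):

# Read the quantization parameters of the uint8 output tensor.
output_details = interpreter.get_output_details()[0]
scale, zero_point = output_details['quantization']
print('output scale: %f, zero point: %d' % (scale, zero_point))
# Given a uint8 output q_out from interpreter.get_tensor(...), the float
# scores are recovered as (q_out.astype(np.float32) - zero_point) * scale.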

Save the models as files

You'll need a .tflite file to deploy your model on other devices. So let's save the converted models to files and then load them when we run inferences below.

import pathlib

tflite_models_dir = pathlib.Path("/tmp/mnist_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)

# Save the unquantized/float model:
tflite_model_file = tflite_models_dir/"mnist_model.tflite"
tflite_model_file.write_bytes(tflite_model)
# Save the quantized model:
tflite_model_quant_file = tflite_models_dir/"mnist_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_model_quant)
24608

Run the TensorFlow Lite models

Now we'll run inferences using the TensorFlow Lite Interpreter to compare the model accuracies.

First, we need a function that runs inference with a given model and images, and then returns the predictions:

# Helper function to run inference on a TFLite model
def run_tflite_model(tflite_file, test_image_indices):
  global test_images

  # Initialize the interpreter
  interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]

  predictions = np.zeros((len(test_image_indices),), dtype=int)
  for i, test_image_index in enumerate(test_image_indices):
    test_image = test_images[test_image_index]
    test_label = test_labels[test_image_index]

    # Check if the input type is quantized, then rescale input data to uint8
    if input_details['dtype'] == np.uint8:
      input_scale, input_zero_point = input_details["quantization"]
      test_image = test_image / input_scale + input_zero_point

    test_image = np.expand_dims(test_image, axis=0).astype(input_details["dtype"])
    interpreter.set_tensor(input_details["index"], test_image)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details["index"])[0]

    predictions[i] = output.argmax()

  return predictions

Test the models on one image

Now we'll compare the performance of the float model and the quantized model:

  • tflite_model_file is the original TensorFlow Lite model with float data.
  • tflite_model_quant_file is the last model we converted using integer-only quantization (it uses uint8 data for input and output).

Let's create another function to print our predictions:

import matplotlib.pylab as plt

# Change this to test a different image
test_image_index = 1

## Helper function to test the models on one image
def test_model(tflite_file, test_image_index, model_type):
  global test_labels

  predictions = run_tflite_model(tflite_file, [test_image_index])

  plt.imshow(test_images[test_image_index])
  template = model_type + " Model \n True:{true}, Predicted:{predict}"
  _ = plt.title(template.format(true=str(test_labels[test_image_index]), predict=str(predictions[0])))
  plt.grid(False)

Now test the float model:

test_model(tflite_model_file, test_image_index, model_type="Float")
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.

[figure: the test digit with the float model's true and predicted labels in the title]

Then test the quantized model:

test_model(tflite_model_quant_file, test_image_index, model_type="Quantized")

[figure: the test digit with the quantized model's true and predicted labels in the title]

Evaluate the models on all images

Now let's run both models using all the test images we loaded at the beginning of this tutorial:

# Helper function to evaluate a TFLite model on all images
def evaluate_model(tflite_file, model_type):
  global test_images
  global test_labels

  test_image_indices = range(test_images.shape[0])
  predictions = run_tflite_model(tflite_file, test_image_indices)

  accuracy = (np.sum(test_labels == predictions) * 100) / len(test_images)

  print('%s model accuracy is %.4f%% (Number of test samples=%d)' % (
      model_type, accuracy, len(test_images)))

Evaluate the float model:

evaluate_model(tflite_model_file, model_type="Float")
Float model accuracy is 97.7500% (Number of test samples=10000)

Evaluate the quantized model:

evaluate_model(tflite_model_quant_file, model_type="Quantized")
Quantized model accuracy is 97.7000% (Number of test samples=10000)

So you now have an integer quantized model with almost no difference in accuracy compared to the float model.

To learn more about other quantization strategies, read about TensorFlow Lite model optimization.