Post-training dynamic range quantization


Overview

TensorFlow Lite now supports converting weights to 8-bit precision as part of model conversion from TensorFlow GraphDefs to TensorFlow Lite's flatbuffer format. Dynamic range quantization achieves a 4x reduction in model size. In addition, TFLite supports on-the-fly quantization and dequantization of activations to allow for:

  1. Using quantized kernels for faster implementation when available.
  2. Mixing of floating-point kernels with quantized kernels for different parts of the graph.

The activations are always stored in floating point. For ops that support quantized kernels, the activations are quantized to 8 bits of precision dynamically prior to processing and are dequantized to float precision after processing. Depending on the model being converted, this can give a speedup over pure floating-point computation.
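To build intuition for what a quantized kernel does with each tensor, the following is a rough NumPy sketch of symmetric per-tensor 8-bit quantization. This is an illustrative scheme only, not the exact arithmetic of the TFLite kernels, which manage their own scales and zero points.

import numpy as np

def quantize_symmetric(x):
  # Map float values to int8 using a single per-tensor scale.
  scale = np.abs(x).max() / 127.0
  q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
  return q, scale

def dequantize(q, scale):
  # Recover approximate float values from the quantized representation.
  return q.astype(np.float32) * scale

x = np.random.randn(4).astype(np.float32)
q, scale = quantize_symmetric(x)
print(x)
print(dequantize(q, scale))  # close to x, within one quantization step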

In contrast to quantization-aware training, this method quantizes the weights after training and quantizes the activations dynamically at inference. The model weights are therefore not retrained to compensate for quantization-induced errors. It is important to check the accuracy of the quantized model to ensure that the degradation is acceptable.

This tutorial trains an MNIST model from scratch, checks its accuracy in TensorFlow, and then converts the model into a TensorFlow Lite flatbuffer with dynamic range quantization. Finally, it checks the accuracy of the converted model and compares it to the original float model.

Build an MNIST model

Setup

import logging
logging.getLogger("tensorflow").setLevel(logging.DEBUG)

import tensorflow as tf
from tensorflow import keras
import numpy as np
import pathlib
2022-10-20 13:35:57.099725: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-10-20 13:35:57.099843: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-10-20 13:35:57.099853: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

Train a TensorFlow model

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_data=(test_images, test_labels)
)
1875/1875 [==============================] - 7s 3ms/step - loss: 0.2691 - accuracy: 0.9257 - val_loss: 0.1276 - val_accuracy: 0.9641
<keras.callbacks.History at 0x7f9d61d63550>

Since you trained the model for just a single epoch, it only reaches ~96% accuracy.

Convert to a TensorFlow Lite model

Using the TFLiteConverter, you can now convert the trained model into a TensorFlow Lite model:

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpss58mksg/assets
2022-10-20 13:36:10.996215: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-10-20 13:36:10.996253: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.

Write it out to a tflite file:

tflite_models_dir = pathlib.Path("/tmp/mnist_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)
tflite_model_file = tflite_models_dir/"mnist_model.tflite"
tflite_model_file.write_bytes(tflite_model)
84824

To quantize the model on export, set the optimizations flag to optimize for size:

converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
tflite_model_quant_file = tflite_models_dir/"mnist_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpuux_rvj7/assets
2022-10-20 13:36:11.913375: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-10-20 13:36:11.913410: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
24072
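
If you want to confirm which tensors were quantized, one option is to inspect the tensor details of the converted model. The sketch below is an optional sanity check: it lists each tensor's dtype, and the weight tensors should now report int8 while activations remain float32.

check_interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
check_interpreter.allocate_tensors()
for detail in check_interpreter.get_tensor_details():
  print(detail["name"], detail["dtype"])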

Note how the resulting file is approximately 1/4 the size of the original.

ls -lh {tflite_models_dir}
total 136K
-rw-rw-r-- 1 kbuilder kbuilder 83K Oct 20 13:36 mnist_model.tflite
-rw-rw-r-- 1 kbuilder kbuilder 24K Oct 20 13:36 mnist_model_quant.tflite
-rw-rw-r-- 1 kbuilder kbuilder 25K Oct 20 13:27 mnist_model_quant_16x8.tflite
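
You can also compute the compression ratio directly from the file sizes, for example:

import os

float_size = os.path.getsize(tflite_model_file)
quant_size = os.path.getsize(tflite_model_quant_file)
print("Float model:     %d bytes" % float_size)
print("Quantized model: %d bytes (%.1fx smaller)" % (quant_size, float_size / quant_size))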

Run the TFLite models

Run the TensorFlow Lite model using the Python TensorFlow Lite Interpreter.

Load the model into an interpreter

interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
interpreter.allocate_tensors()
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))
interpreter_quant.allocate_tensors()

Test the model on one image

test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)

input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

interpreter.set_tensor(input_index, test_image)
interpreter.invoke()
predictions = interpreter.get_tensor(output_index)
import matplotlib.pylab as plt

plt.imshow(test_images[0])
template = "True:{true}, predicted:{predict}"
_ = plt.title(template.format(true= str(test_labels[0]),
                              predict=str(np.argmax(predictions[0]))))
plt.grid(False)

[Figure: the first MNIST test image, with the true and predicted labels shown in the title]

Evaluate the models

# A helper function to evaluate the TF Lite model using "test" dataset.
def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for test_image in test_images:
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  # Compare prediction results with ground truth labels to calculate accuracy.
  accurate_count = 0
  for index in range(len(prediction_digits)):
    if prediction_digits[index] == test_labels[index]:
      accurate_count += 1
  accuracy = accurate_count * 1.0 / len(prediction_digits)

  return accuracy
print(evaluate_model(interpreter))
0.9641

Repeat the evaluation on the dynamic range quantized model to obtain:

print(evaluate_model(interpreter_quant))
0.9637

In this example, the compressed model shows essentially no difference in accuracy.
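
Dynamic range quantization can also reduce latency when quantized kernels are available for the ops in the model. A rough way to check on your own machine is to time repeated invocations of each interpreter; the sketch below is illustrative only, and results vary by hardware and TensorFlow build.

import time

def benchmark(interpreter, runs=200):
  input_index = interpreter.get_input_details()[0]["index"]
  image = np.expand_dims(test_images[0], axis=0).astype(np.float32)
  # Warm-up invocation, then time repeated runs.
  interpreter.set_tensor(input_index, image)
  interpreter.invoke()
  start = time.perf_counter()
  for _ in range(runs):
    interpreter.set_tensor(input_index, image)
    interpreter.invoke()
  return (time.perf_counter() - start) / runs

print("float:     %.3f ms/inference" % (benchmark(interpreter) * 1e3))
print("quantized: %.3f ms/inference" % (benchmark(interpreter_quant) * 1e3))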

Optimizing an existing model

ResNets with pre-activation layers (ResNet-v2) are widely used for vision applications. A pre-trained frozen graph for ResNet-v2-101 is available on TensorFlow Hub.

You can convert the frozen graph to a TensorFlow Lite flatbuffer with quantization as follows:

import tensorflow_hub as hub

resnet_v2_101 = tf.keras.Sequential([
  keras.layers.InputLayer(input_shape=(224, 224, 3)),
  hub.KerasLayer("https://tfhub.dev/google/imagenet/resnet_v2_101/classification/4")
])

converter = tf.lite.TFLiteConverter.from_keras_model(resnet_v2_101)
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
# Convert to TF Lite without quantization
resnet_tflite_file = tflite_models_dir/"resnet_v2_101.tflite"
resnet_tflite_file.write_bytes(converter.convert())
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpzoal5hje/assets
2022-10-20 13:36:40.808550: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-10-20 13:36:40.808602: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
178422332
# Convert to TF Lite with quantization
converter.optimizations = [tf.lite.Optimize.DEFAULT]
resnet_quantized_tflite_file = tflite_models_dir/"resnet_v2_101_quantized.tflite"
resnet_quantized_tflite_file.write_bytes(converter.convert())
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp_exld338/assets
2022-10-20 13:37:06.805582: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-10-20 13:37:06.805640: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
45733976
ls -lh {tflite_models_dir}/*.tflite
-rw-rw-r-- 1 kbuilder kbuilder  83K Oct 20 13:36 /tmp/mnist_tflite_models/mnist_model.tflite
-rw-rw-r-- 1 kbuilder kbuilder  24K Oct 20 13:36 /tmp/mnist_tflite_models/mnist_model_quant.tflite
-rw-rw-r-- 1 kbuilder kbuilder  25K Oct 20 13:27 /tmp/mnist_tflite_models/mnist_model_quant_16x8.tflite
-rw-rw-r-- 1 kbuilder kbuilder 171M Oct 20 13:36 /tmp/mnist_tflite_models/resnet_v2_101.tflite
-rw-rw-r-- 1 kbuilder kbuilder  44M Oct 20 13:37 /tmp/mnist_tflite_models/resnet_v2_101_quantized.tflite

The model size reduces from 171 MB to 44 MB. The accuracy of this model on ImageNet can be evaluated using the scripts provided for TFLite accuracy measurement.
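
Before running a full evaluation, a quick smoke test can confirm that the quantized model executes end to end. The sketch below feeds a random input through the model; the output shape is an assumption based on this hub classification model's 1001-class head.

resnet_interpreter = tf.lite.Interpreter(model_path=str(resnet_quantized_tflite_file))
resnet_interpreter.allocate_tensors()
input_index = resnet_interpreter.get_input_details()[0]["index"]
output_index = resnet_interpreter.get_output_details()[0]["index"]
dummy_input = np.random.rand(1, 224, 224, 3).astype(np.float32)
resnet_interpreter.set_tensor(input_index, dummy_input)
resnet_interpreter.invoke()
print(resnet_interpreter.get_tensor(output_index).shape)  # assumed (1, 1001) class logits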

The optimized model's top-1 accuracy is 76.8%, the same as the floating-point model.