int16 활성화를 사용한 훈련 후 정수 양자화

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

개요

TensorFlow Lite는 이제 TensorFlow에서 TensorFlow Lite의 플랫 버퍼 형식으로 모델을 변환하는 동안 활성화를 16bit 정수 값으로, 가중치를 8bit 정수 값으로 변환하는 작업을 지원합니다. 이 모드를 "16x8 양자화 모드"라고 합니다. 이 모드는 활성화가 양자화에 민감할 때 양자화된 모델의 정확성을 크게 향상시키면서도 모델 크기를 거의 3 ~ 4배 줄일 수 있습니다. 또한 이 완전 양자화된 모델은 정수 전용 하드웨어 가속기에서 사용할 수 있습니다.

이 훈련 후 양자화 모드의 이점을 얻는 모델의 몇 가지 예는 다음과 같습니다.

  • 초고해상도
  • 소음 상쇄 및 빔포밍과 같은 오디오 신호 처리
  • 이미지 노이즈 제거
  • 단일 이미지에서 HDR 재구성

이 튜토리얼에서는 MNIST 모델을 처음부터 학습시키고 TensorFlow에서 정확성을 확인한 다음, 이 모드를 사용하여 모델을 Tensorflow Lite 플랫 버퍼로 변환합니다. 마지막으로, 변환된 모델의 정확성을 확인하고 원래 float32 모델과 비교합니다. 이 예제는 이 모드의 사용법을 보여주는 것이며 TensorFlow Lite에서 사용 가능한 다른 양자화 기술과 비교한 이점을 보여주지는 않습니다.

MNIST 모델 빌드하기

설정

import logging
logging.getLogger("tensorflow").setLevel(logging.DEBUG)

import tensorflow as tf
from tensorflow import keras
import numpy as np
import pathlib
2022-12-15 00:58:13.702544: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 00:58:13.702639: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-15 00:58:13.702648: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

16x8 양자화 모드를 사용할 수 있는지 확인합니다.

tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
<OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8: 'EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8'>

모델 훈련 및 내보내기

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_data=(test_images, test_labels)
)
1875/1875 [==============================] - 8s 3ms/step - loss: 0.3165 - accuracy: 0.9095 - val_loss: 0.1584 - val_accuracy: 0.9553
<keras.callbacks.History at 0x7f7b3fd1fd00>

예를 들어, 단일 epoch에 대해서만 모델을 훈련시켰으므로 ~96% 정확성으로만 훈련합니다.

TensorFlow Lite 모델로 변환하기

TensorFlow Lite Converter를 사용하여 이제 훈련된 모델을 TensorFlow Lite 모델로 변환할 수 있습니다.

이제 TFliteConverter를 사용하여 모델을 기본 float32 형식으로 변환합니다.

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmplvt0uukx/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmplvt0uukx/assets
2022-12-15 00:58:28.021632: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-15 00:58:28.021671: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.

.tflite 파일에 작성합니다.

tflite_models_dir = pathlib.Path("/tmp/mnist_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)
tflite_model_file = tflite_models_dir/"mnist_model.tflite"
tflite_model_file.write_bytes(tflite_model)
84820

대신 모델을 16x8 양자화 모드로 양자화하려면 먼저 기본 최적화를 사용하도록 optimizations 플래그를 설정합니다. 그런 다음 16x8 양자화 모드가 대상 사양에서 지원되는 필수 연산임을 지정합니다.

converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8]

int8 훈련 후 양자화의 경우와 마찬가지로, 변환기 옵션 inference_input(output)_type을 tf.int16으로 설정하여 완전히 정수 양자화된 모델을 생성할 수 있습니다.

보정 데이터를 설정합니다.

mnist_train, _ = tf.keras.datasets.mnist.load_data()
images = tf.cast(mnist_train[0], tf.float32) / 255.0
mnist_ds = tf.data.Dataset.from_tensor_slices((images)).batch(1)
def representative_data_gen():
  for input_value in mnist_ds.take(100):
    # Model has only one input so each data point has one element.
    yield [input_value]
converter.representative_dataset = representative_data_gen

마지막으로, 평소와 같이 모델을 변환합니다. 기본적으로 변환된 모델은 호출 편의를 위해 여전히 float 입력 및 출력을 사용합니다.

tflite_16x8_model = converter.convert()
tflite_model_16x8_file = tflite_models_dir/"mnist_model_quant_16x8.tflite"
tflite_model_16x8_file.write_bytes(tflite_16x8_model)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpiqja0kzc/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpiqja0kzc/assets
2022-12-15 00:58:29.502394: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-15 00:58:29.502436: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
24960

결과 파일이 약 1/3 크기인 것에 주목합니다.

ls -lh {tflite_models_dir}
total 156K
-rw-rw-r-- 1 kbuilder kbuilder 83K Dec 15 00:58 mnist_model.tflite
-rw-rw-r-- 1 kbuilder kbuilder 25K Dec 15 00:58 mnist_model_quant_16x8.tflite
-rw-rw-r-- 1 kbuilder kbuilder 44K Dec 15 00:56 mnist_model_quant_f16.tflite

TensorFlow Lite 모델 실행하기

Python TensorFlow Lite 인터프리터를 사용하여 TensorFlow Lite 모델을 실행합니다.

인터프리터에 모델 로드하기

interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
interpreter.allocate_tensors()
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
interpreter_16x8 = tf.lite.Interpreter(model_path=str(tflite_model_16x8_file))
interpreter_16x8.allocate_tensors()

하나의 이미지에서 모델 테스트하기

test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)

input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

interpreter.set_tensor(input_index, test_image)
interpreter.invoke()
predictions = interpreter.get_tensor(output_index)
import matplotlib.pylab as plt

plt.imshow(test_images[0])
template = "True:{true}, predicted:{predict}"
_ = plt.title(template.format(true= str(test_labels[0]),
                              predict=str(np.argmax(predictions[0]))))
plt.grid(False)

png

test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)

input_index = interpreter_16x8.get_input_details()[0]["index"]
output_index = interpreter_16x8.get_output_details()[0]["index"]

interpreter_16x8.set_tensor(input_index, test_image)
interpreter_16x8.invoke()
predictions = interpreter_16x8.get_tensor(output_index)
plt.imshow(test_images[0])
template = "True:{true}, predicted:{predict}"
_ = plt.title(template.format(true= str(test_labels[0]),
                              predict=str(np.argmax(predictions[0]))))
plt.grid(False)

png

모델 평가하기

# A helper function to evaluate the TF Lite model using "test" dataset.
def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for test_image in test_images:
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  # Compare prediction results with ground truth labels to calculate accuracy.
  accurate_count = 0
  for index in range(len(prediction_digits)):
    if prediction_digits[index] == test_labels[index]:
      accurate_count += 1
  accuracy = accurate_count * 1.0 / len(prediction_digits)

  return accuracy
print(evaluate_model(interpreter))
0.9553

16x8 양자화된 모델에 대해 평가를 반복합니다.

# NOTE: This quantization mode is an experimental post-training mode,
# it does not have any optimized kernels implementations or
# specialized machine learning hardware accelerators. Therefore,
# it could be slower than the float interpreter.
print(evaluate_model(interpreter_16x8))
0.9554

이 예에서는 정확성에 차이가 없지만 3배 축소된 크기로 모델을 16x8로 양자화했습니다.