TFLite 코드를 TF2로 마이그레이션하기

TensorFlow.org에서보기 Google Colab에서 실행하기 GitHub에서 소스 보기 노트북 다운로드하기

TensorFlow Lite(TFLite)는 개발자가 장치(모바일, 임베디드, IoT 장치)에서 ML 추론을 실행하도록 돕는 도구 세트입니다. TFLite 변환기는 기존 TF 모델을 장치에서 효율적으로 실행할 수 있는 최적화된 TFLite 모델 형식으로 변환하는 도구입니다.

이 문서에서는 TF를 TFLite로 변환하기 위해 변경해야 하는 변환 코드에 대해 배우고 동일한 작업을 수행하는 몇 가지 예제를 살펴봅니다.

TF를 TFLite로 변환하기 위해 변경해야 하는 변환 코드

  • 레거시 TF1 모델 형식(Keras 파일, 고정된 GraphDef, 체크포인트, tf.Session 등)을 사용하는 경우 TF1/TF2 SavedModel로 업데이트하고 TF2 변환기 API tf.lite.TFLiteConverter.from_saved_model(...)를 사용하여 TFLite 모델로 변환합니다(표 1 참조).

  • 변환기 API 플래그를 업데이트합니다(표 2 참조).

  • tflite.constants와 같은 레거시 API를 제거합니다(예: tf.lite.constants.INT8tf.int8로 교체).

// 표 1 // TFLite Python 변환기 API 업데이트

TF1 API TF2 API
tf.lite.TFLiteConverter.from_saved_model('saved_model/',..) 지원됨
tf.lite.TFLiteConverter.from_keras_model_file('model.h5',..) 제거됨(SavedModel 형식으로 업데이트)
tf.lite.TFLiteConverter.from_frozen_graph('model.pb',..) 제거됨(SavedModel 형식으로 업데이트)
tf.lite.TFLiteConverter.from_session(sess,...) 제거됨(SavedModel 형식으로 업데이트)

<style> .table {margin-left: 0 !important;} </style>

// 표 2 // TFLite Python 변환기 API 플래그 업데이트

TF1 API TF2 API
allow_custom_ops
optimizations
representative_dataset
target_spec
inference_input_type
inference_output_type
experimental_new_converter
experimental_new_quantizer
지원됨







input_tensors
output_tensors
input_arrays_with_shape
output_arrays
experimental_debug_info_func
제거됨(지원되지 않는 변환기 API 인수)




change_concat_input_ranges
default_ranges_stats
get_input_arrays()
inference_type
quantized_input_stats
reorder_across_fake_quant
제거됨(지원되지 않는 양자화 워크플로)





conversion_summary_dir
dump_graphviz_dir
dump_graphviz_video
제거됨(대신 네트론 또는 visualize.py를 사용하여 모델을 시각화)


output_format
drop_control_dependency
제거됨(TF2에서 지원되지 않는 특성)

예제

이제 레거시 TF1 모델을 TF1/TF2 SavedModels로 변환한 다음 TF2 TFLite 모델로 변환하는 몇 가지 예제를 살펴보겠습니다.

설치하기

필요한 TensorFlow 가져오기로 시작합니다.

import tensorflow as tf
import tensorflow.compat.v1 as tf1
import numpy as np

import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

import shutil
def remove_dir(path):
  try:
    shutil.rmtree(path)
  except:
    pass
2022-12-14 20:54:34.598561: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:54:34.598658: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:54:34.598666: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

필요한 모든 TF1 모델 형식을 생성합니다.

# Create a TF1 SavedModel
SAVED_MODEL_DIR = "tf_saved_model/"
remove_dir(SAVED_MODEL_DIR)
with tf1.Graph().as_default() as g:
  with tf1.Session() as sess:
    input = tf1.placeholder(tf.float32, shape=(3,), name='input')
    output = input + 2
    # print("result: ", sess.run(output, {input: [0., 2., 4.]}))
    tf1.saved_model.simple_save(
        sess, SAVED_MODEL_DIR,
        inputs={'input': input}, 
        outputs={'output': output})
print("TF1 SavedModel path: ", SAVED_MODEL_DIR)

# Create a TF1 Keras model
KERAS_MODEL_PATH = 'tf_keras_model.h5'
model = tf1.keras.models.Sequential([
    tf1.keras.layers.InputLayer(input_shape=(128, 128, 3,), name='input'),
    tf1.keras.layers.Dense(units=16, input_shape=(128, 128, 3,), activation='relu'),
    tf1.keras.layers.Dense(units=1, name='output')
])
model.save(KERAS_MODEL_PATH, save_format='h5')
print("TF1 Keras Model path: ", KERAS_MODEL_PATH)

# Create a TF1 frozen GraphDef model
GRAPH_DEF_MODEL_PATH = tf.keras.utils.get_file(
    'mobilenet_v1_0.25_128',
    origin='https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_128_frozen.tgz',
    untar=True,
) + '/frozen_graph.pb'

print("TF1 frozen GraphDef path: ", GRAPH_DEF_MODEL_PATH)
TF1 SavedModel path:  tf_saved_model/
TF1 Keras Model path:  tf_keras_model.h5
Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_128_frozen.tgz
2617289/2617289 [==============================] - 0s 0us/step
TF1 frozen GraphDef path:  /home/kbuilder/.keras/datasets/mobilenet_v1_0.25_128/frozen_graph.pb

1. TF1 SavedModel을 TFLite 모델로 변환하기

전: TF1로 변환하기

다음은 TF1 스타일의 TFlite 변환에 사용하는 일반적인 코드입니다.

converter = tf1.lite.TFLiteConverter.from_saved_model(
    saved_model_dir=SAVED_MODEL_DIR,
    input_arrays=['input'],
    input_shapes={'input' : [3]}
)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()
# Ignore warning: "Use '@tf.function' or '@defun' to decorate the function."
2022-12-14 20:54:39.163912: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-14 20:54:39.163946: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2022-12-14 20:54:39.163953: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:371] Ignored change_concat_input_ranges.

후: TF2로 변환하기

더 작은 v2 변환기 플래그가 설정을 사용하여 TF1 SavedModel을 TFLite 모델로 직접 변환합니다.

# Convert TF1 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=SAVED_MODEL_DIR)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
tflite_model = converter.convert()
2022-12-14 20:54:39.215359: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-14 20:54:39.215403: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.

2. TF1 Keras 모델 파일을 TFLite 모델로 변환하기

전: TF1로 변환하기

다음은 TF1 스타일의 TFlite 변환에 사용하는 일반적인 코드입니다.

converter = tf1.lite.TFLiteConverter.from_keras_model_file(model_file=KERAS_MODEL_PATH)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()
2022-12-14 20:54:40.216305: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-14 20:54:40.216340: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2022-12-14 20:54:40.216347: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:371] Ignored change_concat_input_ranges.

후: TF2로 변환하기

먼저 TF1 Keras 모델 파일을 TF2 SavedModel로 변환한 다음 더 작은 v2 변환기 플래그 설정을 사용하여 TFLite 모델로 변환합니다.

# Convert TF1 Keras model file to TF2 SavedModel.
model = tf.keras.models.load_model(KERAS_MODEL_PATH)
model.save(filepath='saved_model_2/')

# Convert TF2 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_2/')
tflite_model = converter.convert()
2022-12-14 20:54:40.889897: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-14 20:54:40.889935: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.

3. TF1 고정 GraphDef를 TFLite 모델로 변환하기

전: TF1로 변환하기

다음은 TF1 스타일의 TFlite 변환에 사용하는 일반적인 코드입니다.

converter = tf1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=GRAPH_DEF_MODEL_PATH,
    input_arrays=['input'],
    input_shapes={'input' : [1, 128, 128, 3]},
    output_arrays=['MobilenetV1/Predictions/Softmax'],
)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()
2022-12-14 20:54:41.080555: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-14 20:54:41.080592: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2022-12-14 20:54:41.080601: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:371] Ignored change_concat_input_ranges.

후: TF2로 변환하기

먼저 TF1 고정 GraphDef를 TF1 SavedModel로 변환한 다음 더 작은 v2 변환기 플래그 설정을 사용하여 TFLite 모델로 변환합니다.

## Convert TF1 frozen Graph to TF1 SavedModel.

# Load the graph as a v1.GraphDef
import pathlib
gdef = tf.compat.v1.GraphDef()
gdef.ParseFromString(pathlib.Path(GRAPH_DEF_MODEL_PATH).read_bytes())

# Convert the GraphDef to a tf.Graph
with tf.Graph().as_default() as g:
  tf.graph_util.import_graph_def(gdef, name="")

# Look up the input and output tensors.
input_tensor = g.get_tensor_by_name('input:0') 
output_tensor = g.get_tensor_by_name('MobilenetV1/Predictions/Softmax:0')

# Save the graph as a TF1 Savedmodel
remove_dir('saved_model_3/')
with tf.compat.v1.Session(graph=g) as s:
  tf.compat.v1.saved_model.simple_save(
      session=s,
      export_dir='saved_model_3/',
      inputs={'input':input_tensor},
      outputs={'output':output_tensor})

# Convert TF1 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_3/')
converter.optimizations = {tf.lite.Optimize.DEFAULT}
tflite_model = converter.convert()
2022-12-14 20:54:41.890371: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-14 20:54:41.890408: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.

추가 자료

  • 워크플로 및 최신 특성에 대한 자세한 내용은 TFLite 가이드를 참고하세요.
  • TF1 코드 또는 레거시 TF1 모델 형식(Keras .h5 파일, 고정 GraphDef .pb 등)을 사용하는 경우 코드를 업데이트하고 모델을 TF2 SavedModel 모델 형식으로 마이그레이션하세요.