TensorFlow Lite를 사용한 초고해상도

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드 TF 허브 모델 보기

개요

저해상도 대응물에서 고해상도(HR) 이미지를 복구하는 작업을 일반적으로 SISR(단일 이미지 초해상도)이라고 합니다.

(여기에 사용 된 모델은 ESRGAN입니다 : 향상된 슈퍼 해상도 제너 적대적 네트워크 ESRGAN을 ). 그리고 TensorFlow Lite를 사용하여 사전 훈련된 모델에서 추론을 실행할 것입니다.

TFLite 모델이 변환되어 구현 TF 허브에 개최했다. 우리가 변환한 모델은 50x50 저해상도 이미지를 200x200 고해상도 이미지로 업샘플링했습니다(배율 계수=4). 다른 입력 크기나 배율을 원하면 원래 모델을 다시 변환하거나 다시 훈련해야 합니다.

설정

먼저 필요한 라이브러리를 설치합시다.

pip install matplotlib tensorflow tensorflow-hub

종속성을 가져옵니다.

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2.7.0

ESRGAN 모델 다운로드 및 변환

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
  return concrete_func(input);

converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 335). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets
INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets
2021-11-16 12:15:19.621471: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-11-16 12:15:19.621517: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

테스트 이미지(곤충 머리)를 다운로드합니다.

test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg
16384/6432 [============================================================================] - 0s 0us/step

TensorFlow Lite를 사용하여 초해상도 이미지 생성

lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)

결과 시각화

lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)        
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());

bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)   
plt.title('Bicubic')
plt.imshow(bicubic.numpy());

png

png

성능 벤치마크

성능 벤치 마크 번호는 도구를 사용하여 생성됩니다 여기에 설명 .

모델명 모델 크기 장치 CPU GPU
초해상도(ESRGAN) 4.8MB 픽셀 3 586.8ms* 128.6ms
픽셀 4 385.1ms* 130.3ms

사용 * 4 개 스레드