Super résolution avec TensorFlow Lite

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier Voir le modèle TF Hub

Aperçu

La tâche consistant à récupérer une image haute résolution (HR) à partir de son homologue basse résolution est communément appelée Single Image Super Resolution (SISR).

Le modèle utilisé ici est ESRGAN ( ESRGAN: Super-résolution améliorée génératives accusatoires Networks ). Et nous allons utiliser TensorFlow Lite pour exécuter l'inférence sur le modèle pré-entraîné.

Le modèle TFLite est converti à partir de cette mise en œuvre hébergé sur TF Hub. Notez que le modèle que nous avons converti suréchantillonne une image basse résolution 50x50 en une image haute résolution 200x200 (facteur d'échelle=4). Si vous souhaitez une taille d'entrée ou un facteur d'échelle différent, vous devez reconvertir ou réentraîner le modèle d'origine.

Installer

Installons d'abord les bibliothèques requises.

pip install matplotlib tensorflow tensorflow-hub

Importer des dépendances.

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2.7.0

Téléchargez et convertissez le modèle ESRGAN

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
  return concrete_func(input);

converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 335). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets
INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets
2021-11-16 12:15:19.621471: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-11-16 12:15:19.621517: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

Téléchargez une image test (tête d'insecte).

test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg
16384/6432 [============================================================================] - 0s 0us/step

Générez une image en super résolution à l'aide de TensorFlow Lite

lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)

Visualisez le résultat

lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)        
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());

bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)   
plt.title('Bicubic')
plt.imshow(bicubic.numpy());

png

png

Références de performance

Les numéros de référence de performance sont générés avec l'outil décrit ici .

Nom du modèle Taille du modèle Appareil CPU GPU
super résolution (ESRGAN) 4,8 Mo Pixel 3 586,8 ms* 128,6 ms
Pixel 4 385,1 ms* 130.3ms

* 4 fils utilisés