Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub | Descargar cuaderno | Ver modelo TF Hub |
Descripción general
La tarea de recuperar una imagen de alta resolución (HR) de su contraparte de baja resolución se denomina comúnmente superresolución de imagen única (SISR).
El modelo utilizado aquí es ESRGAN ( ESRGAN: Enhanced Super-Resolución generativos Acusatorios Redes ). Y usaremos TensorFlow Lite para ejecutar inferencias en el modelo previamente entrenado.
El modelo TFLite se convierte de esta aplicación alojada en TF concentradores. Tenga en cuenta que el modelo que convertimos muestra una imagen de baja resolución de 50x50 a una imagen de alta resolución de 200x200 (factor de escala = 4). Si desea un tamaño de entrada o factor de escala diferente, debe volver a convertir o volver a entrenar el modelo original.
Configuración
Primero instalemos las bibliotecas necesarias.
pip install matplotlib tensorflow tensorflow-hub
Importar dependencias.
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2.7.0
Descargue y convierta el modelo ESRGAN
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
return concrete_func(input);
converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
f.write(tflite_model)
esrgan_model_path = './ESRGAN.tflite'
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 335). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets 2021-11-16 12:15:19.621471: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2021-11-16 12:15:19.621517: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
Descarga una imagen de prueba (cabeza de insecto).
test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg 16384/6432 [============================================================================] - 0s 0us/step
Genera una imagen de súper resolución con TensorFlow Lite
lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()
# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)
Visualiza el resultado
lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());
bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)
plt.title('Bicubic')
plt.imshow(bicubic.numpy());
Benchmarks de desempeño
Números de referencia de rendimiento son generados con la herramienta descrita aquí .
Nombre del modelo | Tamaño del modelo | Dispositivo | UPC | GPU |
---|---|---|---|---|
superresolución (ESRGAN) | 4,8 Mb | Pixel 3 | 586,8 ms * | 128,6 ms |
Pixel 4 | 385,1 ms * | 130,3 ms |
* 4 hilos utilizados