Transfert de style artistique avec TensorFlow Lite

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier Voir le modèle TF Hub

L' un des développements les plus passionnants dans l' apprentissage en profondeur de sortir est récemment le transfert de style artistique , ou la possibilité de créer une nouvelle image, connue sous le nom d' un pastiche , basé sur deux images d'entrée: un représentant le style artistique et un représentant le contenu.

Exemple de transfert de style

En utilisant cette technique, nous pouvons générer de belles nouvelles œuvres d'art dans une gamme de styles.

Exemple de transfert de style

Si vous débutez avec TensorFlow Lite et que vous travaillez avec Android, nous vous recommandons d'explorer les exemples d'applications suivants qui peuvent vous aider à démarrer.

Par exemple Android exemple iOS

Si vous utilisez une autre plate - forme que Android ou iOS, ou si vous êtes déjà familier avec les tensorflow API Lite , vous pouvez suivre ce tutoriel pour apprendre à appliquer le transfert de style sur une paire de contenu et de l' image de style avec une tensorflow pré-formation Lite maquette. Vous pouvez utiliser le modèle pour ajouter un transfert de style à vos propres applications mobiles.

Le modèle est open source sur GitHub . Vous pouvez recycler le modèle avec différents paramètres (par exemple, augmenter le poids des couches de contenu pour que l'image de sortie ressemble davantage à l'image de contenu).

Comprendre l'architecture du modèle

Architecture du modèle

Ce modèle de transfert de style artistique se compose de deux sous-modèles :

  1. Style Prediciton Modèle: Un réseau neuronal basé MobilenetV2 qui prend une image de style d'entrée à un vecteur de goulot d' étranglement de style 100-dimension.
  2. Style de transformation de modèle: Un réseau de neurones qui s'applique un vecteur de goulot d' étranglement de style à une image contenu et crée une image stylisée.

Si votre application doit uniquement prendre en charge un ensemble fixe d'images de style, vous pouvez calculer leurs vecteurs de goulot d'étranglement de style à l'avance et exclure le modèle de prédiction de style du binaire de votre application.

Installer

Importer des dépendances.

import tensorflow as tf
print(tf.__version__)
2.6.0
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

Téléchargez les images de contenu et de style, ainsi que les modèles TensorFlow Lite pré-entraînés.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
466944/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
122880/108525 [=================================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
2842624/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step
294912/284398 [===============================] - 0s 0us/step

Pré-traiter les entrées

  • L'image de contenu et l'image de style doivent être des images RVB avec des valeurs de pixel étant des nombres float32 compris entre [0..1].
  • La taille de l'image de style doit être (1, 256, 256, 3). Nous recadrons l'image au centre et la redimensionnons.
  • L'image du contenu doit être (1, 384, 384, 3). Nous recadrons l'image au centre et la redimensionnons.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

Visualiser les entrées

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

Transfert de style d'exécution avec TensorFlow Lite

Prédiction de style

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

Transformation de style

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

Mélange de styles

Nous pouvons mélanger le style de l'image de contenu dans la sortie stylisée, ce qui à son tour fait que la sortie ressemble davantage à l'image de contenu.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \
                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

Références de performance

Les numéros de référence de performance sont générés avec l'outil décrit ici .

Nom du modèle Taille du modèle Appareil NNAPI CPU GPU
Modèle de prédiction de style (int8) 2,8 Mo Pixel 3 (Android 10) 142ms 14ms
Pixel 4 (Android 10) 5.2ms 6.7ms
iPhone XS (iOS 12.4.1) 10,7 ms
Modèle de transformation de style (int8) 0,2 Mo Pixel 3 (Android 10) 540ms
Pixel 4 (Android 10) 405 ms
iPhone XS (iOS 12.4.1) 251ms
Modèle de prédiction de style (float16) 4,7 Mo Pixel 3 (Android 10) 86ms 28ms 9.1ms
Pixel 4 (Android 10) 32ms 12ms 10 ms
Modèle de transfert de style (float16) 0,4 Mo Pixel 3 (Android 10) 1095ms 545 ms 42 ms
Pixel 4 (Android 10) 603ms 377ms 42 ms

* 4 fils utilisés.
** 2 threads sur iPhone pour les meilleures performances.