此页面由 Cloud Translation API 翻译。
Switch to English

Künstlerische Stilübertragung mit TensorFlow Lite

Ansicht auf TensorFlow.org Quelle auf GitHub anzeigen Notizbuch herunterladen

Eine der aufregendsten Entwicklungen im Bereich Deep Learning, die in letzter Zeit herausgekommen sind, ist die Übertragung des künstlerischen Stils oder die Fähigkeit, ein neues Bild zu erstellen, das als Pastiche bezeichnet wird und auf zwei Eingabebildern basiert: eines für den künstlerischen Stil und eines für den Inhalt.

Beispiel für die Stilübertragung

Mit dieser Technik können wir schöne neue Kunstwerke in einer Reihe von Stilen erzeugen.

Beispiel für die Stilübertragung

Wenn Sie TensorFlow Lite noch nicht kennen und mit Android arbeiten, empfehlen wir Ihnen, die folgenden Beispielanwendungen zu untersuchen, die Ihnen den Einstieg erleichtern können.

Android Beispiel iOS Beispiel

Wenn Sie eine andere Plattform als Android oder iOS verwenden oder bereits mit den TensorFlow Lite-APIs vertraut sind, können Sie in diesem Lernprogramm lernen, wie Sie mit einem vorab trainierten TensorFlow Lite die Stilübertragung auf jedes Paar von Inhalten und Stilbildern anwenden Modell. Sie können das Modell verwenden, um Ihren eigenen mobilen Anwendungen eine Stilübertragung hinzuzufügen.

Das Modell ist Open-Source auf GitHub . Sie können das Modell mit verschiedenen Parametern neu trainieren (z. B. die Gewichtung der Inhaltsebenen erhöhen, damit das Ausgabebild dem Inhaltsbild ähnlicher wird).

Verstehen Sie die Modellarchitektur

Modellarchitektur

Dieses Artistic Style Transfer-Modell besteht aus zwei Untermodellen:

  1. Style Prediciton Model : Ein MobilenetV2-basiertes neuronales Netzwerk, das ein Bild im Eingabestil in einen Engpassvektor im Stil mit 100 Dimensionen umwandelt.
  2. Stiltransformationsmodell : Ein neuronales Netzwerk, das einen Stilengpassvektor auf ein Inhaltsbild anwendet und ein stilisiertes Bild erstellt.

Wenn Ihre App nur einen festen Satz von Stilbildern unterstützen muss, können Sie deren Stilengpassvektoren im Voraus berechnen und das Stilvorhersagemodell aus der Binärdatei Ihrer App ausschließen.

Installieren

Abhängigkeiten importieren.

import tensorflow as tf
print(tf.__version__)
2.3.0

import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

Laden Sie die Inhalte und Stilbilder sowie die vorgefertigten TensorFlow Lite-Modelle herunter.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://hub.tensorflow.google.cn/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://hub.tensorflow.google.cn/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
Downloading data from https://hub.tensorflow.google.cn/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
Downloading data from https://hub.tensorflow.google.cn/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step

Eingaben vorverarbeiten

  • Das Inhaltsbild und das Stilbild müssen RGB-Bilder sein, wobei die Pixelwerte float32-Zahlen zwischen [0..1] sind.
  • Die Stilbildgröße muss (1, 256, 256, 3) sein. Wir beschneiden das Bild zentral und ändern die Größe.
  • Das Inhaltsbild muss (1, 384, 384, 3) sein. Wir beschneiden das Bild zentral und ändern die Größe.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

Visualisieren Sie die Eingaben

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

Führen Sie die Stilübertragung mit TensorFlow Lite aus

Stilvorhersage

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

Stilumwandlung

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

Stilmischung

Wir können den Stil des Inhaltsbilds in die stilisierte Ausgabe mischen, wodurch die Ausgabe eher dem Inhaltsbild ähnelt.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \

                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

Leistungsbenchmarks

Leistungsbenchmarkzahlen werden mit dem hier beschriebenen Tool generiert.

Modellname Modellgröße Gerät NNAPI Zentralprozessor GPU
Stilvorhersagemodell (int8) 2,8 Mb Pixel 3 (Android 10) 142ms 14ms
Pixel 4 (Android 10) 5,2 ms 6,7 ms
iPhone XS (iOS 12.4.1) 10,7 ms
Stiltransformationsmodell (int8) 0,2 Mb Pixel 3 (Android 10) 540 ms
Pixel 4 (Android 10) 405ms
iPhone XS (iOS 12.4.1) 251ms
Stilvorhersagemodell (float16) 4,7 Mb Pixel 3 (Android 10) 86ms 28ms 9,1 ms
Pixel 4 (Android 10) 32ms 12ms 10ms
Stilübertragungsmodell (float16) 0,4 Mb Pixel 3 (Android 10) 1095 ms 545 ms 42ms
Pixel 4 (Android 10) 603 ms 377 ms 42ms

* 4 Threads verwendet.
** 2 Threads auf dem iPhone für die beste Leistung.