Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Transfer stylu artystycznego z TensorFlow Lite

Zobacz na TensorFlow.org Wyświetl źródło na GitHub Pobierz notatnik

Jednym z najbardziej ekscytujących osiągnięć w uczeniu głębokim, które pojawiły się w ostatnim czasie, jest transfer stylu artystycznego lub możliwość stworzenia nowego obrazu, znanego jako pastisz , w oparciu o dwa obrazy wejściowe: jeden reprezentujący styl artystyczny i jeden przedstawiający treść.

Przykład transferu stylu

Korzystając z tej techniki, możemy tworzyć piękne nowe dzieła sztuki w różnych stylach.

Przykład transferu stylu

Jeśli jesteś nowym użytkownikiem TensorFlow Lite i pracujesz z systemem Android, zalecamy zapoznanie się z poniższymi przykładowymi aplikacjami, które mogą pomóc w rozpoczęciu pracy.

Przykład systemu Android Przykład iOS

Jeśli korzystasz z platformy innej niż Android lub iOS lub znasz już interfejsy API TensorFlow Lite , możesz skorzystać z tego samouczka, aby dowiedzieć się, jak zastosować transfer stylów do dowolnej pary treści i obrazu stylu za pomocą wstępnie wyszkolonego TensorFlow Lite Model. Możesz użyć modelu, aby dodać transfer stylów do własnych aplikacji mobilnych.

Model jest typu open source w serwisie GitHub . Możesz ponownie wytrenować model z innymi parametrami (np. Zwiększyć wagi warstw zawartości, aby obraz wyjściowy wyglądał bardziej jak obraz zawartości).

Zrozum architekturę modelu

Architektura modelu

Ten model transferu stylu artystycznego składa się z dwóch podmodeli:

  1. Model Prediciton Model : Sieć neuronowa oparta na MobilenetV2, która przenosi obraz w stylu wejściowym do 100-wymiarowego wektora wąskiego gardła.
  2. Model transformacji stylu : sieć neuronowa, która stosuje stylowy wektor wąskiego gardła do obrazu treści i tworzy stylizowany obraz.

Jeśli Twoja aplikacja musi obsługiwać tylko ustalony zestaw obrazów stylów, możesz z wyprzedzeniem obliczyć ich wektory wąskiego gardła stylu i wykluczyć model przewidywania stylu z pliku binarnego aplikacji.

Ustawiać

Importuj zależności.

import tensorflow as tf
print(tf.__version__)
2.3.0

import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

Pobierz zawartość i obrazy stylów oraz wstępnie wyszkolone modele TensorFlow Lite.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://hub.tensorflow.google.cn/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://hub.tensorflow.google.cn/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
Downloading data from https://hub.tensorflow.google.cn/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
Downloading data from https://hub.tensorflow.google.cn/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step

Przetwórz wstępnie dane wejściowe

  • Obraz zawartości i obraz stylu muszą być obrazami RGB z wartościami pikseli będącymi liczbami float32 z przedziału [0..1].
  • Rozmiar obrazu stylu musi wynosić (1, 256, 256, 3). Centralnie przycinamy obraz i zmieniamy jego rozmiar.
  • Obraz treści musi mieć postać (1, 384, 384, 3). Centralnie przycinamy obraz i zmieniamy jego rozmiar.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

Wizualizuj dane wejściowe

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

Uruchom transfer stylowy dzięki TensorFlow Lite

Przewidywanie stylu

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

Zmień styl

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

Mieszanie stylów

Możemy połączyć styl obrazu treści ze stylizowanym wyjściem, co z kolei sprawi, że wynik będzie wyglądał bardziej jak obraz zawartości.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5 

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \

                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

Testy wydajności

Numery testów wydajności są generowane za pomocą opisanego tutaj narzędzia.

Nazwa modelu Rozmiar modelu Urządzenie NNAPI procesor GPU
Model przewidywania stylu (int8) 2,8 Mb Pixel 3 (Android 10) 142ms 14 ms
Pixel 4 (Android 10) 5,2 ms 6,7 ms
iPhone XS (iOS 12.4.1) 10,7 ms
Model transformacji stylu (int8) 0,2 Mb Pixel 3 (Android 10) 540 ms
Pixel 4 (Android 10) 405 ms
iPhone XS (iOS 12.4.1) 251ms
Model przewidywania stylu (float16) 4,7 Mb Pixel 3 (Android 10) 86ms 28 ms 9,1 ms
Pixel 4 (Android 10) 32 ms 12 ms 10 ms
Model transferu stylu (float16) 0,4 Mb Pixel 3 (Android 10) 1095 ms 545ms 42ms
Pixel 4 (Android 10) 603ms 377ms 42 ms

* 4 używane nici.
** 2 wątki na iPhonie dla najlepszej wydajności.