Visualizza su TensorFlow.org | Esegui in Google Colab | Visualizza la fonte su GitHub | Scarica taccuino | Vedi il modello del mozzo TF |
Uno dei più interessanti sviluppi in profondità di apprendimento di uscire di recente è il trasferimento artistico stile , o la possibilità di creare una nuova immagine, conosciuto come un pastiche , sulla base di due immagini in ingresso: uno che rappresenta lo stile artistico e uno che rappresenta il contenuto.
Usando questa tecnica, possiamo generare bellissime nuove opere d'arte in una gamma di stili.
Se non conosci TensorFlow Lite e stai lavorando con Android, ti consigliamo di esplorare le seguenti applicazioni di esempio che possono aiutarti a iniziare.
Se si utilizza una piattaforma diversa da Android o iOS, o si ha già familiarità con i Lite API tensorflow , è possibile seguire questo tutorial per imparare ad applicare il trasferimento stile su una qualsiasi coppia di stile di immagine contenuti e con un pre-addestrato tensorflow Lite modello. Puoi utilizzare il modello per aggiungere il trasferimento di stile alle tue applicazioni mobili.
Il modello è open-source su GitHub . È possibile riqualificare il modello con parametri diversi (ad esempio, aumentare i pesi dei livelli di contenuto per rendere l'immagine di output più simile all'immagine di contenuto).
Comprendere l'architettura del modello
Questo modello Artistic Style Transfer è costituito da due sottomodelli:
- Stile Prediciton Modello: A-MobilenetV2 basa rete neurale che prende un'immagine stile di input ad un 100-dimensione stile collo di bottiglia vettore.
- Stile Transform Modello: Una rete neurale che prende applicare un vettore stile collo di bottiglia a un'immagine contenuti e crea un'immagine stilizzata.
Se la tua app deve supportare solo un set fisso di immagini di stile, puoi calcolare in anticipo i relativi vettori del collo di bottiglia dello stile ed escludere il modello di previsione dello stile dal file binario della tua app.
Impostare
Importa le dipendenze.
import tensorflow as tf
print(tf.__version__)
2.6.0
import IPython.display as display
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False
import numpy as np
import time
import functools
Scarica il contenuto e le immagini di stile e i modelli TensorFlow Lite pre-addestrati.
content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')
style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg 458752/458481 [==============================] - 0s 0us/step 466944/458481 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg 114688/108525 [===============================] - 0s 0us/step 122880/108525 [=================================] - 0s 0us/step Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite 2834432/2828838 [==============================] - 0s 0us/step 2842624/2828838 [==============================] - 0s 0us/step Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite 286720/284398 [==============================] - 0s 0us/step 294912/284398 [===============================] - 0s 0us/step
Pre-elaborare gli input
- L'immagine del contenuto e l'immagine dello stile devono essere immagini RGB con valori di pixel che sono numeri float32 compresi tra [0..1].
- La dimensione dell'immagine dello stile deve essere (1, 256, 256, 3). Ritagliamo centralmente l'immagine e la ridimensioniamo.
- L'immagine del contenuto deve essere (1, 384, 384, 3). Ritagliamo centralmente l'immagine e la ridimensioniamo.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
img = tf.io.read_file(path_to_img)
img = tf.io.decode_image(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
img = img[tf.newaxis, :]
return img
# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
# Resize the image so that the shorter dimension becomes 256px.
shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
short_dim = min(shape)
scale = target_dim / short_dim
new_shape = tf.cast(shape * scale, tf.int32)
image = tf.image.resize(image, new_shape)
# Central crop the image.
image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)
return image
# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)
# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)
print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3) Content Image Shape: (1, 384, 384, 3)
Visualizza gli ingressi
def imshow(image, title=None):
if len(image.shape) > 3:
image = tf.squeeze(image, axis=0)
plt.imshow(image)
if title:
plt.title(title)
plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')
plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')
Esegui il trasferimento dello stile con TensorFlow Lite
Previsione dello stile
# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
# Load the model.
interpreter = tf.lite.Interpreter(model_path=style_predict_path)
# Set model input.
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)
# Calculate style bottleneck.
interpreter.invoke()
style_bottleneck = interpreter.tensor(
interpreter.get_output_details()[0]["index"]
)()
return style_bottleneck
# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)
Trasformazione dello stile
# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
# Load the model.
interpreter = tf.lite.Interpreter(model_path=style_transform_path)
# Set model input.
input_details = interpreter.get_input_details()
interpreter.allocate_tensors()
# Set model inputs.
interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
interpreter.invoke()
# Transform content image.
stylized_image = interpreter.tensor(
interpreter.get_output_details()[0]["index"]
)()
return stylized_image
# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)
# Visualize the output.
imshow(stylized_image, 'Stylized Image')
Miscelazione di stile
Possiamo fondere lo stile dell'immagine del contenuto nell'output stilizzato, che a sua volta rende l'output più simile all'immagine del contenuto.
# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
preprocess_image(content_image, 256)
)
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5
# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \
+ (1 - content_blending_ratio) * style_bottleneck
# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
preprocessed_content_image)
# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')
Benchmark delle prestazioni
I numeri di riferimento delle prestazioni sono generati con lo strumento qui descritto .
Nome del modello | Dimensioni del modello | Dispositivo | NNAPI | processore | GPU |
---|---|---|---|---|---|
Modello di previsione dello stile (int8) | 2,8 Mb | Pixel 3 (Android 10) | 142 ms | 14 ms | |
Pixel 4 (Android 10) | 5.2ms | 6,7 ms | |||
iPhone XS (iOS 12.4.1) | 10,7 ms | ||||
Modello di trasformazione dello stile (int8) | 0.2 Mb | Pixel 3 (Android 10) | 540 ms | ||
Pixel 4 (Android 10) | 405 ms | ||||
iPhone XS (iOS 12.4.1) | 251 ms | ||||
Modello di previsione dello stile (float16) | 4.7 Mb | Pixel 3 (Android 10) | 86 ms | 28 ms | 9.1ms |
Pixel 4 (Android 10) | 32 ms | 12 ms | 10ms | ||
Modello di trasferimento dello stile (float16) | 0,4 Mb | Pixel 3 (Android 10) | 1095 ms | 545 ms | 42 ms |
Pixel 4 (Android 10) | 603 ms | 377 ms | 42 ms |
* 4 fili utilizzati.
** 2 thread su iPhone per le migliori prestazioni.