הצג באתר TensorFlow.org | הפעל בגוגל קולאב | צפה במקור ב-GitHub | הורד מחברת | ראה דגם TF Hub |
אחת ההתפתחויות המלהיבות ביותר בלמידה עמוקה לצאת הוא לאחרונה העברת סגנון אמנותית , או את היכולת ליצור תמונה חדשה, המכונית פסטיש , המבוססת על שתי תמונות קלט: אחד המייצג את הסגנון האמנותי ואחד המייצגת את התוכן.
באמצעות טכניקה זו, נוכל ליצור יצירות אמנות חדשות ויפות במגוון סגנונות.
אם אתה חדש ב-TensorFlow Lite ועובד עם אנדרואיד, אנו ממליצים לחקור את האפליקציות לדוגמה הבאות שיכולות לעזור לך להתחיל.
אם אתה משתמש בפלטפורמה אחרת מאשר אנדרואיד או iOS, או שאתה כבר מכיר את APIs לייט TensorFlow , אתה יכול לעקוב במדריך זה כדי ללמוד כיצד ליישם העברת בסגנון על כל זוג תוכן ותמונה בסטייל עם לייט מראש מאומן TensorFlow דֶגֶם. אתה יכול להשתמש במודל כדי להוסיף העברת סגנון ליישומים ניידים משלך.
המודל הוא קוד מקור פתוח על GitHub . אתה יכול לאמן מחדש את המודל עם פרמטרים שונים (למשל להגדיל את משקל שכבות התוכן כדי לגרום לתמונת הפלט להיראות יותר כמו תמונת התוכן).
הבן את ארכיטקטורת המודל
מודל העברת סגנון אמנותי זה מורכב משני דגמי משנה:
- סגנון Prediciton דגם: A MobilenetV2 מבוסס רשת עצבית שלוקחת תמונת סגנון קלט כדי וקטור צוואר בקבוק בסגנון 100-ממד.
- סגנון Transform דגם: רשת עצבית שלוקחת להחיל וקטור צוואר בקבוק בסגנון לתמונת תוכן ויוצרת תמונה מסוגננת.
אם האפליקציה שלך צריכה לתמוך רק בסט קבוע של תמונות סגנון, תוכל לחשב את וקטורי צוואר הבקבוק של הסגנון שלהן מראש, ולא לכלול את מודל חיזוי הסגנון מהקובץ הבינארי של האפליקציה שלך.
להכין
תלות בייבוא.
import tensorflow as tf
print(tf.__version__)
2.6.0
import IPython.display as display
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False
import numpy as np
import time
import functools
הורד את תמונות התוכן והסגנון, ואת דגמי TensorFlow Lite שהוכשרו מראש.
content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')
style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg 458752/458481 [==============================] - 0s 0us/step 466944/458481 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg 114688/108525 [===============================] - 0s 0us/step 122880/108525 [=================================] - 0s 0us/step Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite 2834432/2828838 [==============================] - 0s 0us/step 2842624/2828838 [==============================] - 0s 0us/step Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite 286720/284398 [==============================] - 0s 0us/step 294912/284398 [===============================] - 0s 0us/step
עבדו מראש את התשומות
- תמונת התוכן ותמונת הסגנון חייבות להיות תמונות RGB כאשר ערכי הפיקסלים הם מספרי float32 בין [0..1].
- גודל תמונת הסגנון חייב להיות (1, 256, 256, 3). אנו חותכים את התמונה במרכז ומשנים את גודלה.
- תמונת התוכן חייבת להיות (1, 384, 384, 3). אנו חותכים את התמונה במרכז ומשנים את גודלה.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
img = tf.io.read_file(path_to_img)
img = tf.io.decode_image(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
img = img[tf.newaxis, :]
return img
# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
# Resize the image so that the shorter dimension becomes 256px.
shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
short_dim = min(shape)
scale = target_dim / short_dim
new_shape = tf.cast(shape * scale, tf.int32)
image = tf.image.resize(image, new_shape)
# Central crop the image.
image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)
return image
# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)
# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)
print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3) Content Image Shape: (1, 384, 384, 3)
דמיינו את התשומות
def imshow(image, title=None):
if len(image.shape) > 3:
image = tf.squeeze(image, axis=0)
plt.imshow(image)
if title:
plt.title(title)
plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')
plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')
הפעל העברת סגנון עם TensorFlow Lite
חיזוי סגנון
# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
# Load the model.
interpreter = tf.lite.Interpreter(model_path=style_predict_path)
# Set model input.
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)
# Calculate style bottleneck.
interpreter.invoke()
style_bottleneck = interpreter.tensor(
interpreter.get_output_details()[0]["index"]
)()
return style_bottleneck
# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)
שינוי סגנון
# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
# Load the model.
interpreter = tf.lite.Interpreter(model_path=style_transform_path)
# Set model input.
input_details = interpreter.get_input_details()
interpreter.allocate_tensors()
# Set model inputs.
interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
interpreter.invoke()
# Transform content image.
stylized_image = interpreter.tensor(
interpreter.get_output_details()[0]["index"]
)()
return stylized_image
# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)
# Visualize the output.
imshow(stylized_image, 'Stylized Image')
מיזוג סגנון
אנחנו יכולים למזג את סגנון תמונת התוכן לתוך הפלט המסוגנן, מה שבתורו גורם לפלט להיראות יותר כמו תמונת התוכן.
# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
preprocess_image(content_image, 256)
)
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5
# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \
+ (1 - content_blending_ratio) * style_bottleneck
# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
preprocessed_content_image)
# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')
מדדי ביצועים
מספרי benchmark ביצועים נוצרים עם הכלי המתואר כאן .
שם המודל | גודל הדגם | התקן | NNAPI | מעבד | GPU |
---|---|---|---|---|---|
מודל חיזוי סגנון (int8) | 2.8 מגה-ביט | Pixel 3 (אנדרואיד 10) | 142ms | 14 אלפיות השנייה | |
Pixel 4 (אנדרואיד 10) | 5.2 אלפיות השנייה | 6.7 אלפיות השנייה | |||
iPhone XS (iOS 12.4.1) | 10.7 אלפיות השנייה | ||||
מודל שינוי סגנון (int8) | 0.2 מגה-ביט | Pixel 3 (אנדרואיד 10) | 540ms | ||
Pixel 4 (אנדרואיד 10) | 405ms | ||||
iPhone XS (iOS 12.4.1) | 251ms | ||||
מודל חיזוי סגנון (float16) | 4.7 מגה-ביט | Pixel 3 (אנדרואיד 10) | 86ms | 28ms | 9.1 אלפיות השנייה |
Pixel 4 (אנדרואיד 10) | 32ms | 12ms | 10 אלפיות השנייה | ||
דגם העברת סגנון (float16) | 0.4 מגה-ביט | Pixel 3 (אנדרואיד 10) | 1095ms | 545ms | 42ms |
Pixel 4 (אנדרואיד 10) | 603ms | 377ms | 42ms |
* 4 חוטים בשימוש.
** 2 שרשורים באייפון לביצועים הטובים ביותר.