Deep Convolutional Generative Adversarial Network

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

מדריך זה מדגים כיצד ליצור תמונות של ספרות בכתב יד באמצעות Deep Convolutional Generative Adversarial Network (DCGAN). הקוד נכתב באמצעות ה- Keras Sequential API עם לולאת אימון tf.GradientTape .

מה הם GANs?

רשתות יריבות גנרטיביות (GANs) הן אחד הרעיונות המעניינים ביותר במדעי המחשב כיום. שני מודלים מאומנים בו-זמנית על ידי תהליך אדוורסרי. מחולל ("האמן") לומד ליצור דימויים שנראים אמיתיים, בעוד שמאפיין ("מבקר האמנות") לומד להבדיל בין דימויים אמיתיים לזיופים.

תרשים של מחולל ומבחין

במהלך האימון, המחולל משתפר בהדרגה ביצירת תמונות שנראות אמיתיות, בעוד שהמאבחן משתפר בלהבדיל ביניהן. התהליך מגיע לשיווי משקל כאשר המאבחן אינו יכול עוד להבחין בין תמונות אמיתיות לזיופים.

תרשים שני של מחולל ומבחין

מחברת זו מדגים תהליך זה במערך הנתונים של MNIST. האנימציה הבאה מציגה סדרה של תמונות שהופקו על ידי המחולל כפי שהוא אומן במשך 50 עידנים. התמונות מתחילות כרעש אקראי, ומזכירות יותר ויותר ספרות בכתב יד לאורך זמן.

פלט לדוגמה

למידע נוסף על GANs, ראה קורס מבוא ללמידה עמוקה של MIT.

להכין

import tensorflow as tf
tf.__version__
'2.8.0-rc1'
# To generate GIFs
pip install imageio
pip install git+https://github.com/tensorflow/docs
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

טען והכן את מערך הנתונים

אתה תשתמש במערך הנתונים של MNIST כדי לאמן את המחולל ואת המפלה. המחולל יפיק ספרות בכתב יד הדומות לנתוני MNIST.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

צור את הדגמים

גם המחולל וגם המפלה מוגדרים באמצעות ה- Keras Sequential API .

הגנרטור

המחולל משתמש tf.keras.layers.Conv2DTranspose (העלאת דגימה) כדי להפיק תמונה מזרע (רעש אקראי). התחל עם שכבה Dense שלוקחת את ה-Seed הזה כקלט, ולאחר מכן הדגימה מספר פעמים עד שתגיע לגודל התמונה הרצוי של 28x28x1. שימו לב להפעלת tf.keras.layers.LeakyReLU עבור כל שכבה, מלבד שכבת הפלט שמשתמשת ב-tanh.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

השתמש במחולל (עדיין לא מאומן) כדי ליצור תמונה.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f6fe7a04b90>

png

המפלה

המאבחן הוא סיווג תמונות מבוסס CNN.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

השתמש באבחון (עדיין לא מאומן) כדי לסווג את התמונות שנוצרו כאמיתיות או מזויפות. המודל יוכשר להפיק ערכים חיוביים עבור תמונות אמיתיות, וערכים שליליים עבור תמונות מזויפות.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[-0.00339105]], shape=(1, 1), dtype=float32)

הגדר את האובדן והאופטימיזציה

הגדר פונקציות אובדן ואופטימיזציה עבור שני הדגמים.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

אובדן מפלה

שיטה זו מכמתת עד כמה המאבחן מסוגל להבחין בין תמונות אמיתיות לזיופים. הוא משווה את התחזיות של המאבחן על תמונות אמיתיות למערך של 1, ואת התחזיות של המאבחן על תמונות מזויפות (שנוצרות) למערך של 0.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

אובדן גנרטור

ההפסד של המחולל מכמת עד כמה הוא הצליח להערים על המפלה. באופן אינטואיטיבי, אם המחולל פועל היטב, המאבחן יסווג את התמונות המזויפות כאמיתיות (או 1). כאן, השווה את החלטות המפלים על התמונות שנוצרו למערך של 1.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

המאפיין ומייעל המחוללים שונים מכיוון שתכשיר שתי רשתות בנפרד.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

שמור מחסומים

מחברת זו גם מדגים כיצד לשמור ולשחזר מודלים, מה שיכול להועיל במקרה שמשימת אימון ריצה ארוכה תופסק.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

הגדר את לולאת האימון

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# You will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

לולאת האימון מתחילה עם הגנרטור שמקבל סיד אקראי כקלט. הזרע הזה משמש לייצור תמונה. לאחר מכן משתמשים באבחנה כדי לסווג תמונות אמיתיות (שמצוירות ממערך האימונים) ומזייפים תמונות (המיוצרות על ידי המחולל). ההפסד מחושב עבור כל אחד מהמודלים הללו, והשיפועים משמשים לעדכון המחולל והאבחנה.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as you go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

צור ושמור תמונות

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

אימון הדגם

קרא לשיטת train() שהוגדרה לעיל כדי לאמן את המחולל והאבחנה בו זמנית. שימו לב, אימון GANs יכול להיות מסובך. חשוב שהמחולל והמבדיל לא ישתלטו זה על זה (למשל, שיתאמנו בקצב דומה).

בתחילת האימון, התמונות שנוצרו נראות כמו רעש אקראי. ככל שהאימון מתקדם, הספרות שנוצרו ייראו אמיתיות יותר ויותר. לאחר כ-50 עידנים, הם דומים לספרות MNIST. זה עשוי להימשך בערך דקה אחת / תקופה עם הגדרות ברירת המחדל ב-Colab.

train(train_dataset, EPOCHS)

png

שחזר את המחסום העדכני ביותר.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f6ee8136950>

צור GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

השתמש imageio כדי ליצור GIF מונפש באמצעות התמונות שנשמרו במהלך האימון.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

הצעדים הבאים

מדריך זה הראה את הקוד המלא הדרוש לכתיבה והדרכה של GAN. כשלב הבא, אולי תרצה להתנסות עם מערך נתונים אחר, למשל מערך הנתונים בקנה מידה גדול של Celeb Faces Attributes (CelebA) הזמין ב-Kaggle . למידע נוסף על GANs, עיין במדריך NIPS 2016: Generative Adversarial Networks .