Głęboko splotowa generatywna sieć przeciwników

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Ten samouczek pokazuje, jak generować obrazy odręcznych cyfr przy użyciu sieci DCGAN ( Deep Convolutional Generative Adversarial Network ). Kod jest pisany przy użyciu interfejsu API Keras Sequential z pętlą treningową tf.GradientTape .

Co to są GAN?

Generacyjne sieci przeciwstawne (GAN) to jeden z najciekawszych pomysłów współczesnej informatyki. Dwa modele są trenowane jednocześnie przez proces kontradyktoryjny. Generator („artysta”) uczy się tworzyć obrazy, które wyglądają na prawdziwe, podczas gdy dyskryminator („krytyk sztuki”) uczy się odróżniać prawdziwe obrazy od podróbek.

Schemat generatora i dyskryminatora

Podczas treningu generator stopniowo staje się lepszy w tworzeniu obrazów, które wyglądają realnie, podczas gdy dyskryminator staje się lepszy w ich rozróżnianiu. Proces osiąga równowagę, gdy dyskryminator nie jest już w stanie odróżnić prawdziwych obrazów od podróbek.

Drugi schemat generatora i dyskryminatora

Ten notatnik przedstawia ten proces w zestawie danych MNIST. Poniższa animacja przedstawia serię obrazów wytworzonych przez generator , który był szkolony przez 50 epok. Obrazy zaczynają się od losowego szumu iz czasem coraz bardziej przypominają ręcznie pisane cyfry.

przykładowe wyjście

Aby dowiedzieć się więcej o GAN, zapoznaj się z kursem Wprowadzenie do głębokiego uczenia MIT.

Ustawiać

import tensorflow as tf
tf.__version__
'2.8.0-rc1'
# To generate GIFs
pip install imageio
pip install git+https://github.com/tensorflow/docs
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

Załaduj i przygotuj zbiór danych

Użyjesz zestawu danych MNIST do szkolenia generatora i dyskryminatora. Generator wygeneruje odręczne cyfry przypominające dane MNIST.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Stwórz modele

Zarówno generator, jak i dyskryminator są definiowane za pomocą interfejsu API Keras Sequential .

Generator

Generator używa tf.keras.layers.Conv2DTranspose (upsampling) do wytworzenia obrazu z nasiona (losowy szum). Zacznij od warstwy Dense , która pobiera to ziarno jako dane wejściowe, a następnie kilkakrotnie próbkuj, aż osiągniesz żądany rozmiar obrazu 28x28x1. Zwróć uwagę na aktywację tf.keras.layers.LeakyReLU dla każdej warstwy, z wyjątkiem warstwy wyjściowej, która używa tanh.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

Użyj (jeszcze nieprzeszkolonego) generatora, aby utworzyć obraz.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f6fe7a04b90>

png

Dyskryminator

Dyskryminatorem jest klasyfikator obrazów oparty na CNN.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

Użyj (jeszcze nieprzeszkolonego) rozróżniacza, aby sklasyfikować wygenerowane obrazy jako prawdziwe lub fałszywe. Model zostanie wytrenowany tak, aby wyświetlał wartości dodatnie dla prawdziwych obrazów i ujemne dla fałszywych obrazów.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[-0.00339105]], shape=(1, 1), dtype=float32)

Zdefiniuj stratę i optymalizatory

Zdefiniuj funkcje strat i optymalizatory dla obu modeli.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Utrata dyskryminatora

Ta metoda określa ilościowo, jak dobrze dyskryminator jest w stanie odróżnić prawdziwe obrazy od podróbek. Porównuje przewidywania dyskryminatora na rzeczywistych obrazach z tablicą jedynek, a przewidywania dyskryminatora na fałszywych (wygenerowanych) obrazach z tablicą zer.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

Strata generatora

Strata generatora określa ilościowo, jak dobrze był w stanie oszukać dyskryminator. Intuicyjnie, jeśli generator działa dobrze, dyskryminator zaklasyfikuje fałszywe obrazy jako prawdziwe (lub 1). Tutaj porównaj decyzje dyskryminacyjne na wygenerowanych obrazach z tablicą jedynek.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Dyskryminator i optymalizatory generatora są różne, ponieważ będziesz szkolić dwie sieci oddzielnie.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Zapisz punkty kontrolne

Ten notatnik pokazuje również, jak zapisywać i przywracać modele, co może być pomocne w przypadku przerwania długotrwałego zadania szkoleniowego.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Zdefiniuj pętlę treningową

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# You will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

Pętla treningowa zaczyna się od otrzymania przez generator losowego ziarna jako danych wejściowych. To ziarno służy do tworzenia obrazu. Dyskryminator jest następnie używany do klasyfikowania obrazów rzeczywistych (pobranych ze zbioru uczącego) oraz obrazów fałszywych (wytworzonych przez generator). Strata jest obliczana dla każdego z tych modeli, a gradienty są wykorzystywane do aktualizacji generatora i dyskryminatora.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as you go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Generuj i zapisuj obrazy

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

Trenuj modelkę

Wywołaj metodę train() zdefiniowaną powyżej, aby jednocześnie trenować generator i dyskryminator. Uwaga, trenowanie GAN może być trudne. Ważne jest, aby generator i dyskryminator nie przeciążały się nawzajem (np. aby trenowały w podobnym tempie).

Na początku treningu generowane obrazy wyglądają jak losowy szum. W miarę postępu treningu generowane cyfry będą wyglądać coraz bardziej realistycznie. Po około 50 epokach przypominają cyfry MNIST. Przy domyślnych ustawieniach w Colab może to zająć około jednej minuty/epoki.

train(train_dataset, EPOCHS)

png

Przywróć najnowszy punkt kontrolny.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f6ee8136950>

Utwórz GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

Użyj imageio , aby stworzyć animowany gif z obrazów zapisanych podczas treningu.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

Następne kroki

Ten samouczek pokazał kompletny kod niezbędny do napisania i trenowania GAN. W następnym kroku możesz poeksperymentować z innym zbiorem danych, na przykład zbiorem danych Atrybuty twarzy celebrytów na dużą skalę (CelebA) dostępnym na Kaggle . Aby dowiedzieć się więcej o sieciach GAN, zapoznaj się z samouczkiem NIPS 2016: Generative Adversarial Networks .