Cette page a été traduite par l'API Cloud Translation.
Switch to English

Réseau Adversaire Génératif Convolutionnel Profond

Voir sur TensorFlow.org Afficher la source sur GitHub Télécharger le carnet

Ce didacticiel montre comment générer des images de chiffres manuscrits à l'aide d'un réseau DCGAN ( Deep Convolutional Generative Adversarial Network ). Le code est écrit à l'aide de l' API séquentielle Keras avec une boucle d'entraînement tf.GradientTape .

Que sont les GAN?

Les Réseaux Adversaires Génératifs (GAN) sont aujourd'hui l'une des idées les plus intéressantes de l'informatique. Deux modèles sont formés simultanément par un processus contradictoire. Un générateur («l'artiste») apprend à créer des images qui semblent réelles, tandis qu'un discriminateur («le critique d'art») apprend à distinguer les images réelles des contrefaçons.

Un schéma d'un générateur et d'un discriminateur

Au cours de la formation, le générateur devient progressivement meilleur pour créer des images qui semblent réelles, tandis que le discriminateur devient meilleur pour les distinguer. Le processus atteint l'équilibre lorsque le discriminateur ne peut plus distinguer les images réelles des contrefaçons.

Un deuxième schéma d'un générateur et d'un discriminateur

Ce bloc-notes illustre ce processus sur l'ensemble de données MNIST. L'animation suivante montre une série d'images produites par le générateur tel qu'il a été formé pendant 50 époques. Les images commencent par un bruit aléatoire et ressemblent de plus en plus à des chiffres écrits à la main au fil du temps.

sortie d'échantillon

Pour en savoir plus sur les GAN, nous vous recommandons le cours Intro to Deep Learning du MIT.

Installer

import tensorflow as tf
tf.__version__
'2.3.0'
# To generate GIFs
pip install -q imageio
pip install -q git+https://github.com/tensorflow/docs
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

Charger et préparer le jeu de données

Vous utiliserez le jeu de données MNIST pour entraîner le générateur et le discriminateur. Le générateur générera des chiffres manuscrits ressemblant aux données MNIST.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Créer les modèles

Le générateur et le discriminateur sont définis à l'aide de l' API séquentielle Keras .

Le générateur

Le générateur utilise des tf.keras.layers.Conv2DTranspose (suréchantillonnage) pour produire une image à partir d'une graine (bruit aléatoire). Commencez avec un calque Dense qui prend cette graine comme entrée, puis suréchantillonnez plusieurs fois jusqu'à ce que vous atteigniez la taille d'image souhaitée de 28x28x1. Notez l'activation de tf.keras.layers.LeakyReLU pour chaque couche, sauf la couche de sortie qui utilise tanh.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

Utilisez le générateur (encore inexpérimenté) pour créer une image.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f2729b9f6d8>

png

Le discriminateur

Le discriminateur est un classificateur d'image basé sur CNN.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

Utilisez le discriminateur (encore inexpérimenté) pour classer les images générées comme réelles ou fausses. Le modèle sera formé pour produire des valeurs positives pour les images réelles et des valeurs négatives pour les fausses images.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[0.0003284]], shape=(1, 1), dtype=float32)

Définir la perte et les optimiseurs

Définissez les fonctions de perte et les optimiseurs pour les deux modèles.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Perte du discriminateur

Cette méthode quantifie à quel point le discriminateur est capable de distinguer les images réelles des fausses. Il compare les prédictions du discriminateur sur des images réelles à un tableau de 1 et les prédictions du discriminateur sur de fausses images (générées) à un tableau de 0.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

Perte du générateur

La perte du générateur quantifie à quel point il a réussi à tromper le discriminateur. Intuitivement, si le générateur fonctionne bien, le discriminateur classera les fausses images comme réelles (ou 1). Ici, nous comparerons les décisions des discriminateurs sur les images générées à un tableau de 1.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Le discriminateur et les optimiseurs de générateur sont différents puisque nous allons former deux réseaux séparément.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Enregistrer les points de contrôle

Ce bloc-notes montre également comment enregistrer et restaurer des modèles, ce qui peut être utile en cas d'interruption d'une tâche d'entraînement de longue durée.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Définir la boucle d'entraînement

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# We will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

La boucle d'apprentissage commence avec le générateur recevant une graine aléatoire en entrée. Cette graine est utilisée pour produire une image. Le discriminateur est ensuite utilisé pour classer des images réelles (tirées de l'ensemble d'apprentissage) et des images factices (produites par le générateur). La perte est calculée pour chacun de ces modèles, et les gradients sont utilisés pour mettre à jour le générateur et le discriminateur.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Générer et enregistrer des images

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

Former le modèle

Appelez la méthode train() définie ci-dessus pour entraîner simultanément le générateur et le discriminateur. Notez que la formation des GAN peut être délicate. Il est important que le générateur et le discriminateur ne se maîtrisent pas (par exemple, qu'ils s'entraînent à un rythme similaire).

Au début de la formation, les images générées ressemblent à du bruit aléatoire. Au fur et à mesure que la formation progresse, les chiffres générés seront de plus en plus réels. Après environ 50 époques, ils ressemblent à des chiffres MNIST. Cela peut prendre environ une minute / époque avec les paramètres par défaut de Colab.

train(train_dataset, EPOCHS)

png

Restaurez le dernier point de contrôle.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2729bc3128>

Créer un GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

Utilisez imageio pour créer un gif animé en utilisant les images enregistrées pendant l'entraînement.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

Prochaines étapes

Ce tutoriel a montré le code complet nécessaire pour écrire et entraîner un GAN. Dans l'étape suivante, vous souhaiterez peut-être expérimenter avec un autre jeu de données, par exemple le jeu de données d'attributs Celeb Faces à grande échelle (CelebA) disponible sur Kaggle . Pour en savoir plus sur les GAN, nous vous recommandons le Tutoriel NIPS 2016: Generative Adversarial Networks .