Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Głęboka konwolucyjna generatywna sieć adwersarzy

Zobacz na TensorFlow.org Wyświetl źródło na GitHub Pobierz notatnik

Ten samouczek pokazuje, jak generować obrazy odręcznych cyfr za pomocą sieci Deep Convolutional Generative Adversarial Network (DCGAN). Kod jest napisany przy użyciu interfejsu Keras Sequential API z tf.GradientTape szkoleniową tf.GradientTape .

Co to są GAN?

Generative Adversarial Networks (GAN) to jeden z najciekawszych pomysłów współczesnej informatyki. Dwa modele są trenowane jednocześnie w procesie kontradyktoryjnym. Generator („artysta”) uczy się tworzyć obrazy, które wyglądają na prawdziwe, podczas gdy dyskryminator („krytyk sztuki”) uczy się odróżniać rzeczywiste obrazy od podróbek.

Schemat generatora i dyskryminatora

Podczas treningu generator stopniowo staje się lepszy w tworzeniu obrazów, które wyglądają jak prawdziwe, podczas gdy dyskryminator coraz lepiej je rozróżnia. Proces osiąga równowagę, gdy dyskryminator nie jest już w stanie odróżnić prawdziwych obrazów od podróbek.

Drugi schemat generatora i dyskryminatora

Ten notatnik przedstawia ten proces w zestawie danych MNIST. Poniższa animacja przedstawia serię obrazów generowanych przez generator podczas trenowania przez 50 epok. Obrazy zaczynają się jako przypadkowy szum iz czasem coraz bardziej przypominają odręczne cyfry.

przykładowe wyjście

Aby dowiedzieć się więcej o GAN, zalecamy kurs MIT Wprowadzenie do głębokiego uczenia się .

Ustawiać

import tensorflow as tf
tf.__version__
'2.3.0'
# To generate GIFs
pip install -q imageio
pip install -q git+https://github.com/tensorflow/docs
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

Załaduj i przygotuj zbiór danych

Użyjesz zbioru danych MNIST do szkolenia generatora i dyskryminatora. Generator wygeneruje odręczne cyfry przypominające dane MNIST.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Utwórz modele

Zarówno generator, jak i dyskryminator są definiowane za pomocą interfejsu Keras Sequential API .

Generator

Generator wykorzystuje tf.keras.layers.Conv2DTranspose (upsampling) do tworzenia obrazu z ziarna (szum losowy). Zacznij od warstwy Dense , która pobiera to ziarno jako dane wejściowe, a następnie kilkakrotnie zwiększaj próbkowanie, aż osiągniesz żądany rozmiar obrazu 28x28x1. Zwróć uwagę na aktywację tf.keras.layers.LeakyReLU dla każdej warstwy, z wyjątkiem warstwy wyjściowej, która używa tanh.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

Użyj (jeszcze nieprzeszkolonego) generatora, aby utworzyć obraz.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f2729b9f6d8>

png

Dyskryminator

Dyskryminator to klasyfikator obrazów oparty na CNN.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

Użyj (jeszcze niewyszkolonego) dyskryminatora, aby sklasyfikować wygenerowane obrazy jako prawdziwe lub fałszywe. Model zostanie przeszkolony, aby przedstawiał dodatnie wartości dla obrazów rzeczywistych i wartości ujemne dla fałszywych obrazów.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[0.0003284]], shape=(1, 1), dtype=float32)

Zdefiniuj straty i optymalizatory

Zdefiniuj funkcje strat i optymalizatory dla obu modeli.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Utrata dyskryminatora

Ta metoda określa ilościowo, jak dobrze dyskryminator jest w stanie odróżnić prawdziwe obrazy od podróbek. Porównuje przewidywania dyskryminatora dotyczące rzeczywistych obrazów z tablicą jedynek, a przewidywania dyskryminatora dotyczące fałszywych (generowanych) obrazów z tablicą zer.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

Utrata generatora

Strata generatora określa ilościowo, jak dobrze był on w stanie oszukać dyskryminator. Intuicyjnie, jeśli generator działa dobrze, dyskryminator sklasyfikuje fałszywe obrazy jako rzeczywiste (lub 1). Tutaj porównamy decyzje dyskryminatorów na wygenerowanych obrazach z tablicą jedynek.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Dyskryminator i optymalizatory generatora są różne, ponieważ będziemy szkolić dwie sieci oddzielnie.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Zapisz punkty kontrolne

Ten notatnik pokazuje również, jak zapisywać i przywracać modele, co może być pomocne w przypadku przerwania długotrwałego zadania szkoleniowego.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Zdefiniuj pętlę treningową

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# We will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

Pętla szkoleniowa rozpoczyna się, gdy generator otrzymuje losowe ziarno jako dane wejściowe. To ziarno jest używane do tworzenia obrazu. Dyskryminator jest następnie używany do klasyfikowania obrazów rzeczywistych (pochodzących ze zbioru uczącego) i obrazów fałszywych (wytwarzanych przez generator). Straty są obliczane dla każdego z tych modeli, a gradienty są używane do aktualizacji generatora i dyskryminatora.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Generuj i zapisuj obrazy

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

Wytrenuj model

Wywołaj metodę train() zdefiniowaną powyżej, aby jednocześnie uczyć generator i dyskryminator. Uwaga, szkolenie GAN może być trudne. Ważne jest, aby generator i dyskryminator nie przeciążały się nawzajem (np. Trenowały w podobnym tempie).

Na początku treningu wygenerowane obrazy wyglądają jak losowy szum. W miarę postępu treningu generowane cyfry będą wyglądać coraz bardziej realistycznie. Po około 50 epokach przypominają cyfry MNIST. Może to zająć około minuty / epoki przy domyślnych ustawieniach Colab.

train(train_dataset, EPOCHS)

png

Przywróć ostatni punkt kontrolny.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2729bc3128>

Utwórz GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

Użyj imageio aby stworzyć animowany gif, używając obrazów zapisanych podczas treningu.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

Następne kroki

W tym samouczku przedstawiono pełny kod niezbędny do napisania i szkolenia GAN. W następnym kroku możesz poeksperymentować z innym zestawem danych, na przykład zestawem danych Large-scale Celeb Faces Attributes (CelebA) dostępnym w Kaggle . Aby dowiedzieć się więcej o sieciach GAN, polecamy samouczek NIPS 2016: Generative Adversarial Networks .