Zobacz na TensorFlow.org | Uruchom w Google Colab | Wyświetl źródło na GitHub | Pobierz notatnik |
Ten samouczek pokazuje, jak generować obrazy odręcznych cyfr przy użyciu sieci DCGAN ( Deep Convolutional Generative Adversarial Network ). Kod jest pisany przy użyciu interfejsu API Keras Sequential z pętlą treningową tf.GradientTape
.
Co to są GAN?
Generacyjne sieci przeciwstawne (GAN) to jeden z najciekawszych pomysłów współczesnej informatyki. Dwa modele są trenowane jednocześnie przez proces kontradyktoryjny. Generator („artysta”) uczy się tworzyć obrazy, które wyglądają na prawdziwe, podczas gdy dyskryminator („krytyk sztuki”) uczy się odróżniać prawdziwe obrazy od podróbek.
Podczas treningu generator stopniowo staje się lepszy w tworzeniu obrazów, które wyglądają realnie, podczas gdy dyskryminator staje się lepszy w ich rozróżnianiu. Proces osiąga równowagę, gdy dyskryminator nie jest już w stanie odróżnić prawdziwych obrazów od podróbek.
Ten notatnik przedstawia ten proces w zestawie danych MNIST. Poniższa animacja przedstawia serię obrazów wytworzonych przez generator , który był szkolony przez 50 epok. Obrazy zaczynają się od losowego szumu iz czasem coraz bardziej przypominają ręcznie pisane cyfry.
Aby dowiedzieć się więcej o GAN, zapoznaj się z kursem Wprowadzenie do głębokiego uczenia MIT.
Ustawiać
import tensorflow as tf
tf.__version__
'2.8.0-rc1'
# To generate GIFs
pip install imageio
pip install git+https://github.com/tensorflow/docs
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display
Załaduj i przygotuj zbiór danych
Użyjesz zestawu danych MNIST do szkolenia generatora i dyskryminatora. Generator wygeneruje odręczne cyfry przypominające dane MNIST.
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
Stwórz modele
Zarówno generator, jak i dyskryminator są definiowane za pomocą interfejsu API Keras Sequential .
Generator
Generator używa tf.keras.layers.Conv2DTranspose
(upsampling) do wytworzenia obrazu z nasiona (losowy szum). Zacznij od warstwy Dense
, która pobiera to ziarno jako dane wejściowe, a następnie kilkakrotnie próbkuj, aż osiągniesz żądany rozmiar obrazu 28x28x1. Zwróć uwagę na aktywację tf.keras.layers.LeakyReLU
dla każdej warstwy, z wyjątkiem warstwy wyjściowej, która używa tanh.
def make_generator_model():
model = tf.keras.Sequential()
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Reshape((7, 7, 256)))
assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
assert model.output_shape == (None, 7, 7, 128)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
assert model.output_shape == (None, 14, 14, 64)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
assert model.output_shape == (None, 28, 28, 1)
return model
Użyj (jeszcze nieprzeszkolonego) generatora, aby utworzyć obraz.
generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f6fe7a04b90>
Dyskryminator
Dyskryminatorem jest klasyfikator obrazów oparty na CNN.
def make_discriminator_model():
model = tf.keras.Sequential()
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
input_shape=[28, 28, 1]))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Flatten())
model.add(layers.Dense(1))
return model
Użyj (jeszcze nieprzeszkolonego) rozróżniacza, aby sklasyfikować wygenerowane obrazy jako prawdziwe lub fałszywe. Model zostanie wytrenowany tak, aby wyświetlał wartości dodatnie dla prawdziwych obrazów i ujemne dla fałszywych obrazów.
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[-0.00339105]], shape=(1, 1), dtype=float32)
Zdefiniuj stratę i optymalizatory
Zdefiniuj funkcje strat i optymalizatory dla obu modeli.
# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
Utrata dyskryminatora
Ta metoda określa ilościowo, jak dobrze dyskryminator jest w stanie odróżnić prawdziwe obrazy od podróbek. Porównuje przewidywania dyskryminatora na rzeczywistych obrazach z tablicą jedynek, a przewidywania dyskryminatora na fałszywych (wygenerowanych) obrazach z tablicą zer.
def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
Strata generatora
Strata generatora określa ilościowo, jak dobrze był w stanie oszukać dyskryminator. Intuicyjnie, jeśli generator działa dobrze, dyskryminator zaklasyfikuje fałszywe obrazy jako prawdziwe (lub 1). Tutaj porównaj decyzje dyskryminacyjne na wygenerowanych obrazach z tablicą jedynek.
def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
Dyskryminator i optymalizatory generatora są różne, ponieważ będziesz szkolić dwie sieci oddzielnie.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
Zapisz punkty kontrolne
Ten notatnik pokazuje również, jak zapisywać i przywracać modele, co może być pomocne w przypadku przerwania długotrwałego zadania szkoleniowego.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator)
Zdefiniuj pętlę treningową
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16
# You will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])
Pętla treningowa zaczyna się od otrzymania przez generator losowego ziarna jako danych wejściowych. To ziarno służy do tworzenia obrazu. Dyskryminator jest następnie używany do klasyfikowania obrazów rzeczywistych (pobranych ze zbioru uczącego) oraz obrazów fałszywych (wytworzonych przez generator). Strata jest obliczana dla każdego z tych modeli, a gradienty są wykorzystywane do aktualizacji generatora i dyskryminatora.
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
train_step(image_batch)
# Produce images for the GIF as you go
display.clear_output(wait=True)
generate_and_save_images(generator,
epoch + 1,
seed)
# Save the model every 15 epochs
if (epoch + 1) % 15 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
# Generate after the final epoch
display.clear_output(wait=True)
generate_and_save_images(generator,
epochs,
seed)
Generuj i zapisuj obrazy
def generate_and_save_images(model, epoch, test_input):
# Notice `training` is set to False.
# This is so all layers run in inference mode (batchnorm).
predictions = model(test_input, training=False)
fig = plt.figure(figsize=(4, 4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()
Trenuj modelkę
Wywołaj metodę train()
zdefiniowaną powyżej, aby jednocześnie trenować generator i dyskryminator. Uwaga, trenowanie GAN może być trudne. Ważne jest, aby generator i dyskryminator nie przeciążały się nawzajem (np. aby trenowały w podobnym tempie).
Na początku treningu generowane obrazy wyglądają jak losowy szum. W miarę postępu treningu generowane cyfry będą wyglądać coraz bardziej realistycznie. Po około 50 epokach przypominają cyfry MNIST. Przy domyślnych ustawieniach w Colab może to zająć około jednej minuty/epoki.
train(train_dataset, EPOCHS)
Przywróć najnowszy punkt kontrolny.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f6ee8136950>
Utwórz GIF
# Display a single image using the epoch number
def display_image(epoch_no):
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)
Użyj imageio
, aby stworzyć animowany gif z obrazów zapisanych podczas treningu.
anim_file = 'dcgan.gif'
with imageio.get_writer(anim_file, mode='I') as writer:
filenames = glob.glob('image*.png')
filenames = sorted(filenames)
for filename in filenames:
image = imageio.imread(filename)
writer.append_data(image)
image = imageio.imread(filename)
writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)
Następne kroki
Ten samouczek pokazał kompletny kod niezbędny do napisania i trenowania GAN. W następnym kroku możesz poeksperymentować z innym zbiorem danych, na przykład zbiorem danych Atrybuty twarzy celebrytów na dużą skalę (CelebA) dostępnym na Kaggle . Aby dowiedzieć się więcej o sieciach GAN, zapoznaj się z samouczkiem NIPS 2016: Generative Adversarial Networks .