View on TensorFlow.org | Run in Google Colab | View on GitHub | Download notebook | See TF Hub model |
This notebook is a demo for the BigBiGAN models available on TF Hub.

BigBiGAN extends standard (Big)GANs by adding an encoder module which can be used for unsupervised representation learning. Roughly speaking, the encoder inverts the generator by predicting latents z given real data x. See the BigBiGAN paper on arXiv [1] for more information about these models.
After connecting to a runtime, get started by following these instructions:

- (Optional) Update the selected module_path in the first code cell below to load a BigBiGAN generator for a different encoder architecture.
- Click Runtime > Run all to run each cell in order. Afterwards, the outputs, including visualizations of BigBiGAN samples and reconstructions, should automatically appear below.
[1] Jeff Donahue and Karen Simonyan. Large Scale Adversarial Representation Learning. arXiv:1907.02544, 2019.
First, set the module path. By default, we load the smaller BigBiGAN model with the ResNet-50-based encoder from https://tfhub.dev/deepmind/bigbigan-resnet50/1. To load the larger RevNet-50-x4 based model used to achieve the best representation learning results, comment out the active module_path setting and uncomment the other one.
module_path = 'https://tfhub.dev/deepmind/bigbigan-resnet50/1' # ResNet-50
# module_path = 'https://tfhub.dev/deepmind/bigbigan-revnet50x4/1' # RevNet-50 x4
Setup
import io
import IPython.display
import PIL.Image
from pprint import pformat
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tensorflow_hub as hub
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/compat/v2_compat.py:111: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version. Instructions for updating: non-resource variables are not supported in the long term
Define some functions for displaying images
def imgrid(imarray, cols=4, pad=1, padval=255, row_major=True):
"""Lays out a [N, H, W, C] image array as a single image grid."""
pad = int(pad)
if pad < 0:
raise ValueError('pad must be non-negative')
cols = int(cols)
assert cols >= 1
N, H, W, C = imarray.shape
rows = N // cols + int(N % cols != 0)
batch_pad = rows * cols - N
assert batch_pad >= 0
post_pad = [batch_pad, pad, pad, 0]
pad_arg = [[0, p] for p in post_pad]
imarray = np.pad(imarray, pad_arg, 'constant', constant_values=padval)
H += pad
W += pad
grid = (imarray
.reshape(rows, cols, H, W, C)
.transpose(0, 2, 1, 3, 4)
.reshape(rows*H, cols*W, C))
if pad:
grid = grid[:-pad, :-pad]
return grid
def interleave(*args):
"""Interleaves input arrays of the same shape along the batch axis."""
if not args:
raise ValueError('At least one argument is required.')
a0 = args[0]
if any(a.shape != a0.shape for a in args):
raise ValueError('All inputs must have the same shape.')
if not a0.shape:
raise ValueError('Inputs must have at least one axis.')
out = np.transpose(args, [1, 0] + list(range(2, len(a0.shape) + 1)))
out = out.reshape(-1, *a0.shape[1:])
return out
def imshow(a, format='png', jpeg_fallback=True):
"""Displays an image in the given format."""
a = a.astype(np.uint8)
data = io.BytesIO()
PIL.Image.fromarray(a).save(data, format)
im_data = data.getvalue()
try:
disp = IPython.display.display(IPython.display.Image(im_data))
except IOError:
if jpeg_fallback and format != 'jpeg':
print('Warning: image was too large to display in format "{}"; '
      'trying jpeg instead.'.format(format))
return imshow(a, format='jpeg')
else:
raise
return disp
def image_to_uint8(x):
"""Converts [-1, 1] float array to [0, 255] uint8."""
x = np.asarray(x)
x = (256. / 2.) * (x + 1.)
x = np.clip(x, 0, 255)
x = x.astype(np.uint8)
return x
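As a quick illustration of these helpers (this cell is not in the original notebook), the sketch below builds two small random batches, interleaves them, and lays them out as a two-column grid; the shapes are arbitrary.
# Illustrative only: exercise the display helpers on random data.
dummy_a = np.random.uniform(-1., 1., size=(4, 64, 64, 3))  # fake batch in [-1, 1]
dummy_b = np.random.uniform(-1., 1., size=(4, 64, 64, 3))
interleaved = interleave(dummy_a, dummy_b)  # shape (8, 64, 64, 3)
imshow(imgrid(image_to_uint8(interleaved), cols=2))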
Load a BigBiGAN TF Hub module and display its available functionality
# module = hub.Module(module_path, trainable=True, tags={'train'}) # training
module = hub.Module(module_path) # inference
for signature in module.get_signature_names():
print('Signature:', signature)
print('Inputs:', pformat(module.get_input_info_dict(signature)))
print('Outputs:', pformat(module.get_output_info_dict(signature)))
print()
Signature: default
Inputs: {'x': <hub.ParsedTensorInfo shape=(?, 256, 256, 3) dtype=float32 is_sparse=False>}
Outputs: {'default': <hub.ParsedTensorInfo shape=(?, 120) dtype=float32 is_sparse=False>}

Signature: generate
Inputs: {'z': <hub.ParsedTensorInfo shape=(?, 120) dtype=float32 is_sparse=False>}
Outputs: {'default': <hub.ParsedTensorInfo shape=(?, 128, 128, 3) dtype=float32 is_sparse=False>,
 'upsampled': <hub.ParsedTensorInfo shape=(?, 256, 256, 3) dtype=float32 is_sparse=False>}

Signature: discriminate
Inputs: {'x': <hub.ParsedTensorInfo shape=(?, 128, 128, 3) dtype=float32 is_sparse=False>,
 'z': <hub.ParsedTensorInfo shape=(?, 120) dtype=float32 is_sparse=False>}
Outputs: {'score_x': <hub.ParsedTensorInfo shape=(?,) dtype=float32 is_sparse=False>,
 'score_xz': <hub.ParsedTensorInfo shape=(?,) dtype=float32 is_sparse=False>,
 'score_z': <hub.ParsedTensorInfo shape=(?,) dtype=float32 is_sparse=False>}

Signature: encode
Inputs: {'x': <hub.ParsedTensorInfo shape=(?, 256, 256, 3) dtype=float32 is_sparse=False>}
Outputs: {'avepool_feat': <hub.ParsedTensorInfo shape=(?, 2048) dtype=float32 is_sparse=False>,
 'bn_crelu_feat': <hub.ParsedTensorInfo shape=(?, 4096) dtype=float32 is_sparse=False>,
 'default': <hub.ParsedTensorInfo shape=(?, 120) dtype=float32 is_sparse=False>,
 'z_mean': <hub.ParsedTensorInfo shape=(?, 120) dtype=float32 is_sparse=False>,
 'z_sample': <hub.ParsedTensorInfo shape=(?, 120) dtype=float32 is_sparse=False>,
 'z_stdev': <hub.ParsedTensorInfo shape=(?, 120) dtype=float32 is_sparse=False>}
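Before introducing the wrapper class below, note that the raw module can also be called directly with one of these signature names. The following sketch (an illustration, not part of the original notebook) builds a generator output tensor from the 'generate' signature; the wrapper class below does the same thing with a friendlier interface.
# Illustrative only: call the 'generate' signature of the raw module directly.
z_demo = tf.placeholder(tf.float32, shape=[None, 120])  # latent input placeholder
x_demo = module(dict(z=z_demo), signature='generate', as_dict=True)['default']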
Define a wrapper class for easy access to the various functions
class BigBiGAN(object):
def __init__(self, module):
"""Initialize a BigBiGAN from the given TF Hub module."""
self._module = module
def generate(self, z, upsample=False):
"""Run a batch of latents z through the generator to generate images.
Args:
z: A batch of 120D Gaussian latents, shape [N, 120].
Returns: a batch of generated RGB images, shape [N, 128, 128, 3] (or
  [N, 256, 256, 3] if `upsample`), range [-1, 1].
"""
outputs = self._module(z, signature='generate', as_dict=True)
return outputs['upsampled' if upsample else 'default']
def make_generator_ph(self):
"""Creates a tf.placeholder with the dtype & shape of generator inputs."""
info = self._module.get_input_info_dict('generate')['z']
return tf.placeholder(dtype=info.dtype, shape=info.get_shape())
def gen_pairs_for_disc(self, z):
"""Compute generator input pairs (G(z), z) for discriminator, given z.
Args:
z: A batch of latents (120D standard Gaussians), shape [N, 120].
Returns: a tuple (G(z), z) of discriminator inputs.
"""
# The generator's default output is already 128x128, matching the discriminator input.
x = self.generate(z)
return x, z
def encode(self, x, return_all_features=False):
"""Run a batch of images x through the encoder.
Args:
x: A batch of data (256x256 RGB images), shape [N, 256, 256, 3], range
[-1, 1].
return_all_features: If True, return all features computed by the encoder.
Otherwise (default) just return a sample z_hat.
Returns: the sample z_hat of shape [N, 120] (or a dict of all features if
return_all_features).
"""
outputs = self._module(x, signature='encode', as_dict=True)
return outputs if return_all_features else outputs['z_sample']
def make_encoder_ph(self):
"""Creates a tf.placeholder with the dtype & shape of encoder inputs."""
info = self._module.get_input_info_dict('encode')['x']
return tf.placeholder(dtype=info.dtype, shape=info.get_shape())
def enc_pairs_for_disc(self, x):
"""Compute encoder input pairs (x, E(x)) for discriminator, given x.
Args:
x: A batch of data (256x256 RGB images), shape [N, 256, 256, 3], range
[-1, 1].
Returns: a tuple (downsample(x), E(x)) of discriminator inputs.
"""
# Downsample 256x256 image x for 128x128 discriminator input.
x_down = tf.nn.avg_pool(x, ksize=2, strides=2, padding='SAME')
z = self.encode(x)
return x_down, z
def discriminate(self, x, z):
"""Compute the discriminator scores for pairs of data (x, z).
(x, z) must be batches with the same leading batch dimension, and joint
scores are computed on corresponding pairs x[i] and z[i].
Args:
x: A batch of data (128x128 RGB images), shape [N, 128, 128, 3], range
[-1, 1].
z: A batch of latents (120D standard Gaussians), shape [N, 120].
Returns:
A dict of scores:
score_xz: the joint scores for the (x, z) pairs.
score_x: the unary scores for x only.
score_z: the unary scores for z only.
"""
inputs = dict(x=x, z=z)
return self._module(inputs, signature='discriminate', as_dict=True)
def reconstruct_x(self, x, use_sample=True, upsample=False):
"""Compute BigBiGAN reconstructions of images x via G(E(x)).
Args:
x: A batch of data (256x256 RGB images), shape [N, 256, 256, 3], range
[-1, 1].
use_sample: if True (default), use a sample z_hat ~ E(x); otherwise,
  deterministically use the mean z_mean. (Though a sample z_hat may be far
  from the mean, the resulting reconstructions G(z_hat) and G(z_mean) are
  typically very similar.)
upsample: if set, upsample the reconstruction to the input resolution
(256x256). Otherwise return the raw lower resolution generator output
(128x128).
Returns: a batch of recons G(E(x)), shape [N, 256, 256, 3] if
`upsample`, otherwise [N, 128, 128, 3].
"""
if use_sample:
z = self.encode(x)
else:
z = self.encode(x, return_all_features=True)['z_mean']
recons = self.generate(z, upsample=upsample)
return recons
def losses(self, x, z):
"""Compute per-module BigBiGAN losses given data & latent sample batches.
Args:
x: A batch of data (256x256 RGB images), shape [N, 256, 256, 3], range
[-1, 1].
z: A batch of latents (120D standard Gaussians), shape [M, 120].
For the original BigBiGAN losses, pass batches of size N=M=2048, with z's
sampled from a 120D standard Gaussian (e.g., np.random.randn(2048, 120)),
and x's sampled from the ImageNet (ILSVRC2012) training set with the
"ResNet-style" preprocessing from:
https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_preprocessing.py
Returns:
A dict of per-module losses:
disc: loss for the discriminator.
enc: loss for the encoder.
gen: loss for the generator.
"""
# Compute discriminator scores on (x, E(x)) pairs.
# Downsample 256x256 image x for 128x128 discriminator input.
scores_enc_x_dict = self.discriminate(*self.enc_pairs_for_disc(x))
scores_enc_x = tf.concat([scores_enc_x_dict['score_xz'],
scores_enc_x_dict['score_x'],
scores_enc_x_dict['score_z']], axis=0)
# Compute discriminator scores on (G(z), z) pairs.
scores_gen_z_dict = self.discriminate(*self.gen_pairs_for_disc(z))
scores_gen_z = tf.concat([scores_gen_z_dict['score_xz'],
scores_gen_z_dict['score_x'],
scores_gen_z_dict['score_z']], axis=0)
disc_loss_enc_x = tf.reduce_mean(tf.nn.relu(1. - scores_enc_x))
disc_loss_gen_z = tf.reduce_mean(tf.nn.relu(1. + scores_gen_z))
disc_loss = disc_loss_enc_x + disc_loss_gen_z
enc_loss = tf.reduce_mean(scores_enc_x)
gen_loss = tf.reduce_mean(-scores_gen_z)
return dict(disc=disc_loss, enc=enc_loss, gen=gen_loss)
Create tensors to be used later for computing samples, reconstructions, discriminator scores, and losses
bigbigan = BigBiGAN(module)
# Make input placeholders for x (`enc_ph`) and z (`gen_ph`).
enc_ph = bigbigan.make_encoder_ph()
gen_ph = bigbigan.make_generator_ph()
# Compute samples G(z) from encoder input z (`gen_ph`).
gen_samples = bigbigan.generate(gen_ph)
# Compute reconstructions G(E(x)) of encoder input x (`enc_ph`).
recon_x = bigbigan.reconstruct_x(enc_ph, upsample=True)
# Compute encoder features used for representation learning evaluations given
# encoder input x (`enc_ph`).
enc_features = bigbigan.encode(enc_ph, return_all_features=True)
# Compute discriminator scores for encoder pairs (x, E(x)) given x (`enc_ph`)
# and generator pairs (G(z), z) given z (`gen_ph`).
disc_scores_enc = bigbigan.discriminate(*bigbigan.enc_pairs_for_disc(enc_ph))
disc_scores_gen = bigbigan.discriminate(*bigbigan.gen_pairs_for_disc(gen_ph))
# Compute losses.
losses = bigbigan.losses(enc_ph, gen_ph)
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
Create a TensorFlow session and initialize variables
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
Generator samples

First, we'll visualize samples from the pretrained BigBiGAN generator by sampling generator inputs z from a standard Gaussian (via np.random.randn) and displaying the images it produces. So far we're not going beyond the capabilities of a standard GAN -- we're just using the generator (and ignoring the encoder) for now.
feed_dict = {gen_ph: np.random.randn(32, 120)}
_out_samples = sess.run(gen_samples, feed_dict=feed_dict)
print('samples shape:', _out_samples.shape)
imshow(imgrid(image_to_uint8(_out_samples), cols=4))
samples shape: (32, 128, 128, 3)
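As an optional extra not in the original demo, the same gen_samples tensor can be used to visualize a linear interpolation between two random latents; the sketch below assumes eight evenly spaced steps.
# Illustrative only: interpolate between two latents and display the samples.
z0, z1 = np.random.randn(120), np.random.randn(120)
alphas = np.linspace(0., 1., 8)[:, None]  # 8 interpolation weights, shape [8, 1]
z_interp = (1. - alphas) * z0 + alphas * z1  # shape [8, 120]
_interp_samples = sess.run(gen_samples, feed_dict={gen_ph: z_interp})
imshow(imgrid(image_to_uint8(_interp_samples), cols=8))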
Load test_images from the TF-Flowers dataset

BigBiGAN is trained on ImageNet, but as it's too large to work with in this demo, we use the smaller TF-Flowers [1] dataset as our inputs for visualizing reconstructions and computing encoder features.

In this cell we load TF-Flowers (downloading the dataset if needed) and store a fixed batch of 256x256 RGB image samples in a NumPy array test_images.

[1] https://www.tensorflow.org/datasets/catalog/tf_flowers
def get_flowers_data():
"""Returns a [32, 256, 256, 3] np.array of preprocessed TF-Flowers samples."""
import tensorflow_datasets as tfds
ds, info = tfds.load('tf_flowers', split='train', with_info=True)
# Just get the images themselves as we don't need labels for this demo.
ds = ds.map(lambda x: x['image'])
# Filter out small images (with minor edge length <256).
ds = ds.filter(lambda x: tf.reduce_min(tf.shape(x)[:2]) >= 256)
# Take the center square crop of the image and resize to 256x256.
def crop_and_resize(image):
imsize = tf.shape(image)[:2]
minor_edge = tf.reduce_min(imsize)
start = (imsize - minor_edge) // 2
stop = start + minor_edge
cropped_image = image[start[0] : stop[0], start[1] : stop[1]]
resized_image = tf.image.resize_bicubic([cropped_image], [256, 256])[0]
return resized_image
ds = ds.map(crop_and_resize)
# Convert images from [0, 255] uint8 to [-1, 1] float32.
ds = ds.map(lambda image: tf.cast(image, tf.float32) / (255. / 2.) - 1)
# Take the first 32 samples.
ds = ds.take(32)
return np.array(list(tfds.as_numpy(ds)))
test_images = get_flowers_data()
2021-11-05 12:42:36.340550: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
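As a quick sanity check (not part of the original notebook), we can confirm the batch shape and display the loaded images with the helpers defined earlier.
# Illustrative only: inspect the loaded TF-Flowers batch.
print('test_images shape:', test_images.shape)  # expected: (32, 256, 256, 3)
imshow(imgrid(image_to_uint8(test_images), cols=8))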
Reconstructions

Now we visualize BigBiGAN reconstructions by passing real images through the encoder and back through the generator, computing G(E(x)) given images x. Below, input images x are shown in the left column, and the corresponding reconstructions are shown on the right.

Note that the reconstructions are not pixel-perfect matches to the input images; rather, they tend to capture the higher-level semantic content of the input while "forgetting" most of the low-level detail. This suggests the BigBiGAN encoder may learn to capture the types of high-level semantic information about images that we'd like to see in a representation learning approach.

Also note that the raw reconstructions of the 256x256 input images are at the lower resolution produced by our generator -- 128x128. We upsample them for visualization purposes.
test_images_batch = test_images[:16]
_out_recons = sess.run(recon_x, feed_dict={enc_ph: test_images_batch})
print('reconstructions shape:', _out_recons.shape)
inputs_and_recons = interleave(test_images_batch, _out_recons)
print('inputs_and_recons shape:', inputs_and_recons.shape)
imshow(imgrid(image_to_uint8(inputs_and_recons), cols=2))
reconstructions shape: (16, 256, 256, 3)
inputs_and_recons shape: (32, 256, 256, 3)
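Optionally (this comparison is not in the original demo), reconstructions can also be computed from the deterministic encoder mean by passing use_sample=False to reconstruct_x; per the docstring above, the results are typically very similar to the sampled version.
# Illustrative only: reconstruct via the encoder mean z_mean instead of a sample.
recon_x_mean = bigbigan.reconstruct_x(enc_ph, use_sample=False, upsample=True)
_out_recons_mean = sess.run(recon_x_mean, feed_dict={enc_ph: test_images_batch})
imshow(imgrid(image_to_uint8(
    interleave(test_images_batch, _out_recons_mean)), cols=2))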
Encoder features

We now demonstrate how to compute features from the encoder used for standard representation learning evaluations.

These features could be used in a linear or nearest-neighbors-based classifier. We include the standard feature taken after global average pooling (key avepool_feat) as well as the larger "BN+CReLU" feature (key bn_crelu_feat) used to achieve the best results.
_out_features = sess.run(enc_features, feed_dict={enc_ph: test_images_batch})
print('AvePool features shape:', _out_features['avepool_feat'].shape)
print('BN+CReLU features shape:', _out_features['bn_crelu_feat'].shape)
AvePool features shape: (16, 2048)
BN+CReLU features shape: (16, 4096)
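To sketch how such features might be used downstream, the hedged example below fits a linear classifier on the avepool_feat features. The labels here are invented placeholders purely for illustration (the paper evaluates a linear classifier trained on ImageNet labels), and scikit-learn is assumed to be available in the environment.
# Illustrative only: a linear classifier on encoder features with made-up labels.
from sklearn.linear_model import LogisticRegression

features = _out_features['avepool_feat']  # shape [16, 2048]
fake_labels = np.random.randint(0, 2, size=len(features))  # placeholder labels
clf = LogisticRegression(max_iter=1000).fit(features, fake_labels)
print('train accuracy (placeholder labels):', clf.score(features, fake_labels))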
Discriminator scores and losses

Finally, we'll compute the discriminator scores and losses on batches of encoder and generator pairs. These losses could be passed into an optimizer to train BigBiGAN.

We use our batch of images above as the encoder inputs x, computing the encoder score as D(x, E(x)). For the generator inputs we sample z from a 120D standard Gaussian via np.random.randn, computing the generator score as D(G(z), z).

The discriminator predicts a joint score score_xz for the (x, z) pairs as well as unary scores score_x and score_z for x and z alone, respectively. It's trained to give high (positive) scores to encoder pairs and low (negative) scores to generator pairs. This mostly holds below, though the unary score_z is negative in both cases, indicating that the encoder outputs E(x) resemble actual samples from a Gaussian.
feed_dict = {enc_ph: test_images, gen_ph: np.random.randn(32, 120)}
_out_scores_enc, _out_scores_gen, _out_losses = sess.run(
[disc_scores_enc, disc_scores_gen, losses], feed_dict=feed_dict)
print('Encoder scores:', {k: v.mean() for k, v in _out_scores_enc.items()})
print('Generator scores:', {k: v.mean() for k, v in _out_scores_gen.items()})
print('Losses:', _out_losses)
Encoder scores: {'score_xz': 0.6921617, 'score_z': -0.50248873, 'score_x': 1.4621685}
Generator scores: {'score_xz': -0.8883822, 'score_z': -0.45992172, 'score_x': -0.5907474}
Losses: {'disc': 1.2274433, 'enc': 0.55200976, 'gen': 0.64635044}
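Finally, as a hedged sketch only: if the module had been loaded with trainable=True and tags={'train'} (see the commented-out hub.Module call above), the per-module losses could be wired to TF1-style optimizers roughly as below. The variable-name filters ('discriminator', 'generator', 'encoder') are assumptions about the module's variable scopes, not verified here, and may need adjusting.
# Hedged sketch only: hooking the losses up to optimizers for training.
# Assumes a trainable module and that the name filters match its variable scopes.
optimizer = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0., beta2=0.999)
disc_vars = [v for v in tf.trainable_variables() if 'discriminator' in v.name]
gen_enc_vars = [v for v in tf.trainable_variables()
                if 'generator' in v.name or 'encoder' in v.name]
train_disc_op = optimizer.minimize(losses['disc'], var_list=disc_vars)
train_gen_enc_op = optimizer.minimize(losses['gen'] + losses['enc'],
                                      var_list=gen_enc_vars)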