Couches probabilistes TFP : encodeur automatique variationnel

Dans cet exemple, nous montrons comment ajuster un autoencodeur variationnel à l'aide des "couches probabilistes" de TFP.

Dépendances et prérequis

Importer

import numpy as np

import tensorflow.compat.v2 as tf
tf
.enable_v2_behavior()

import tensorflow_datasets as tfds
import tensorflow_probability as tfp


tfk
= tf.keras
tfkl
= tf.keras.layers
tfpl
= tfp.layers
tfd
= tfp.distributions

Faites les choses rapidement !

Avant de plonger, assurons-nous que nous utilisons un GPU pour cette démo.

Pour ce faire, sélectionnez "Runtime" -> "Modifier le type d'exécution" -> "Accélérateur matériel" -> "GPU".

L'extrait suivant vérifiera que nous avons accès à un GPU.

if tf.test.gpu_device_name() != '/device:GPU:0':
 
print('WARNING: GPU device not found.')
else:
 
print('SUCCESS: Found GPU: {}'.format(tf.test.gpu_device_name()))
SUCCESS: Found GPU: /device:GPU:0

Charger l'ensemble de données

datasets, datasets_info = tfds.load(name='mnist',
                                    with_info
=True,
                                    as_supervised
=False)

def _preprocess(sample):
  image
= tf.cast(sample['image'], tf.float32) / 255.  # Scale to unit interval.
  image
= image < tf.random.uniform(tf.shape(image))   # Randomly binarize.
 
return image, image

train_dataset
= (datasets['train']
                 
.map(_preprocess)
                 
.batch(256)
                 
.prefetch(tf.data.AUTOTUNE)
                 
.shuffle(int(10e3)))
eval_dataset
= (datasets['test']
               
.map(_preprocess)
               
.batch(256)
               
.prefetch(tf.data.AUTOTUNE))

Notez que prétraiter () ci - dessus des rendements image, image plutôt que l' image parce que Keras est mis en place pour les modèles discriminants avec un ( par exemple, l' étiquette) format d'entrée, c. -à- \(p\theta(y|x)\). Étant donné que le but de la VAE est de récupérer les entrées x de x lui - même (c. -à- pθ(x|x)), la paire de données est ( par exemple, par exemple).

Code VAE Golf

Spécifiez le modèle.

input_shape = datasets_info.features['image'].shape
encoded_size
= 16
base_depth
= 32
prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1),
                        reinterpreted_batch_ndims
=1)
encoder = tfk.Sequential([
    tfkl
.InputLayer(input_shape=input_shape),
    tfkl
.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
    tfkl
.Conv2D(base_depth, 5, strides=1,
                padding
='same', activation=tf.nn.leaky_relu),
    tfkl
.Conv2D(base_depth, 5, strides=2,
                padding
='same', activation=tf.nn.leaky_relu),
    tfkl
.Conv2D(2 * base_depth, 5, strides=1,
                padding
='same', activation=tf.nn.leaky_relu),
    tfkl
.Conv2D(2 * base_depth, 5, strides=2,
                padding
='same', activation=tf.nn.leaky_relu),
    tfkl
.Conv2D(4 * encoded_size, 7, strides=1,
                padding
='valid', activation=tf.nn.leaky_relu),
    tfkl
.Flatten(),
    tfkl
.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size),
               activation
=None),
    tfpl
.MultivariateNormalTriL(
        encoded_size
,
        activity_regularizer
=tfpl.KLDivergenceRegularizer(prior)),
])
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py:158: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.
Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py:158: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.
Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.
decoder = tfk.Sequential([
    tfkl
.InputLayer(input_shape=[encoded_size]),
    tfkl
.Reshape([1, 1, encoded_size]),
    tfkl
.Conv2DTranspose(2 * base_depth, 7, strides=1,
                         padding
='valid', activation=tf.nn.leaky_relu),
    tfkl
.Conv2DTranspose(2 * base_depth, 5, strides=1,
                         padding
='same', activation=tf.nn.leaky_relu),
    tfkl
.Conv2DTranspose(2 * base_depth, 5, strides=2,
                         padding
='same', activation=tf.nn.leaky_relu),
    tfkl
.Conv2DTranspose(base_depth, 5, strides=1,
                         padding
='same', activation=tf.nn.leaky_relu),
    tfkl
.Conv2DTranspose(base_depth, 5, strides=2,
                         padding
='same', activation=tf.nn.leaky_relu),
    tfkl
.Conv2DTranspose(base_depth, 5, strides=1,
                         padding
='same', activation=tf.nn.leaky_relu),
    tfkl
.Conv2D(filters=1, kernel_size=5, strides=1,
                padding
='same', activation=None),
    tfkl
.Flatten(),
    tfpl
.IndependentBernoulli(input_shape, tfd.Bernoulli.logits),
])
vae = tfk.Model(inputs=encoder.inputs,
                outputs
=decoder(encoder.outputs[0]))

Faites de l'inférence.

negloglik = lambda x, rv_x: -rv_x.log_prob(x)

vae
.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),
            loss
=negloglik)

_
= vae.fit(train_dataset,
            epochs
=15,
            validation_data
=eval_dataset)
Epoch 1/15
235/235 [==============================] - 14s 61ms/step - loss: 206.5541 - val_loss: 163.1924
Epoch 2/15
235/235 [==============================] - 14s 59ms/step - loss: 151.1891 - val_loss: 143.6748
Epoch 3/15
235/235 [==============================] - 14s 58ms/step - loss: 141.3275 - val_loss: 137.9188
Epoch 4/15
235/235 [==============================] - 14s 58ms/step - loss: 136.7453 - val_loss: 133.2726
Epoch 5/15
235/235 [==============================] - 14s 58ms/step - loss: 132.3803 - val_loss: 131.8343
Epoch 6/15
235/235 [==============================] - 14s 58ms/step - loss: 129.2451 - val_loss: 127.1935
Epoch 7/15
235/235 [==============================] - 14s 59ms/step - loss: 126.0975 - val_loss: 123.6789
Epoch 8/15
235/235 [==============================] - 14s 58ms/step - loss: 124.0565 - val_loss: 122.5058
Epoch 9/15
235/235 [==============================] - 14s 58ms/step - loss: 122.9974 - val_loss: 121.9544
Epoch 10/15
235/235 [==============================] - 14s 58ms/step - loss: 121.7349 - val_loss: 120.8735
Epoch 11/15
235/235 [==============================] - 14s 58ms/step - loss: 121.0856 - val_loss: 120.1340
Epoch 12/15
235/235 [==============================] - 14s 58ms/step - loss: 120.2232 - val_loss: 121.3554
Epoch 13/15
235/235 [==============================] - 14s 58ms/step - loss: 119.8123 - val_loss: 119.2351
Epoch 14/15
235/235 [==============================] - 14s 58ms/step - loss: 119.2685 - val_loss: 118.2133
Epoch 15/15
235/235 [==============================] - 14s 59ms/step - loss: 118.8895 - val_loss: 119.4771

Regarde maman, non Mains Tenseurs !

# We'll just examine ten random digits.
x
= next(iter(eval_dataset))[0][:10]
xhat
= vae(x)
assert isinstance(xhat, tfd.Distribution)

Utilitaire de tracé d'image

import matplotlib.pyplot as plt

def display_imgs(x, y=None):
 
if not isinstance(x, (np.ndarray, np.generic)):
    x
= np.array(x)
  plt
.ioff()
  n
= x.shape[0]
  fig
, axs = plt.subplots(1, n, figsize=(n, 1))
 
if y is not None:
    fig
.suptitle(np.argmax(y, axis=1))
 
for i in range(n):
    axs
.flat[i].imshow(x[i].squeeze(), interpolation='none', cmap='gray')
    axs
.flat[i].axis('off')
  plt
.show()
  plt
.close()
  plt
.ion()

print('Originals:')
display_imgs
(x)

print('Decoded Random Samples:')
display_imgs
(xhat.sample())

print('Decoded Modes:')
display_imgs
(xhat.mode())

print('Decoded Means:')
display_imgs
(xhat.mean())
Originals:

png

Decoded Random Samples:

png

Decoded Modes:

png

Decoded Means:

png

# Now, let's generate ten never-before-seen digits.
z
= prior.sample(10)
xtilde
= decoder(z)
assert isinstance(xtilde, tfd.Distribution)
print('Randomly Generated Samples:')
display_imgs
(xtilde.sample())

print('Randomly Generated Modes:')
display_imgs
(xtilde.mode())

print('Randomly Generated Means:')
display_imgs
(xtilde.mean())
Randomly Generated Samples:

png

Randomly Generated Modes:

png

Randomly Generated Means:

png