tfp.experimental.nn.ConvolutionVariationalReparameterization

Convolution layer class with reparameterization estimator.

Inherits From: VariationalLayer, Layer

This layer implements the Bayesian variational inference analogue to a Convolution layer by assuming the kernel and/or the bias are drawn from distributions. By default, the layer implements a stochastic forward pass via sampling from the kernel and bias posteriors,

kernel, bias ~ posterior
outputs = convolution(inputs, kernel) + bias

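It uses the reparameterization estimator [(Kingma and Welling, 2014)][1], which forms a Monte Carlo approximation of the expectation over the kernel and bias. Because each sample is a deterministic function of the posterior parameters and independent noise, gradients flow through the samples to those parameters.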
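To make the stochastic forward pass concrete, here is a minimal NumPy sketch of reparameterized sampling. It assumes a mean-field (diagonal) Normal posterior over every kernel and bias entry, a single unbatched 2-D input, 'VALID' padding, stride 1, and no activation. The helpers `conv2d_valid` and `stochastic_forward` are hypothetical and not part of TFP; in the layer itself, drawing the weights is governed by `posterior_value_fn` (default `tfd.Distribution.sample`) and the convolution is done by TensorFlow.

```python
import numpy as np


def conv2d_valid(x, kernel):
    """Naive 'VALID', stride-1 2-D convolution (cross-correlation, as in most DL libraries).

    x:      [height, width, in_channels]
    kernel: [filter_h, filter_w, in_channels, out_channels]
    """
    fh, fw, _, out_channels = kernel.shape
    h, w, _ = x.shape
    out = np.zeros((h - fh + 1, w - fw + 1, out_channels))
    for i in range(h - fh + 1):
        for j in range(w - fw + 1):
            patch = x[i:i + fh, j:j + fw, :]
            # Contract the patch against the kernel over (filter_h, filter_w, in_channels).
            out[i, j] = np.tensordot(patch, kernel, axes=([0, 1, 2], [0, 1, 2]))
    return out


def stochastic_forward(x, kernel_loc, kernel_scale, bias_loc, bias_scale, rng):
    """One Monte Carlo forward pass: kernel, bias ~ posterior; outputs = conv(x, kernel) + bias."""
    # Reparameterization: each sample is a deterministic function of the
    # posterior parameters and parameter-free standard-normal noise.
    kernel = kernel_loc + kernel_scale * rng.standard_normal(kernel_loc.shape)
    bias = bias_loc + bias_scale * rng.standard_normal(bias_loc.shape)
    return conv2d_valid(x, kernel) + bias


rng = np.random.default_rng(0)
x = rng.uniform(size=(5, 5, 1))               # one 5x5 single-channel "image"
kernel_loc = rng.normal(size=(3, 3, 1, 2))    # posterior means, kernel shape [3, 3, 1, 2]
kernel_scale = np.full_like(kernel_loc, 0.1)  # posterior standard deviations
bias_loc, bias_scale = np.zeros(2), np.full(2, 0.1)

samples = np.stack([
    stochastic_forward(x, kernel_loc, kernel_scale, bias_loc, bias_scale, rng)
    for _ in range(2000)
])
mc_mean = samples.mean(axis=0)                # Monte Carlo estimate of E[outputs]
```

Because the convolution is linear in the kernel, the Monte Carlo mean should approach `conv2d_valid(x, kernel_loc) + bias_loc` as the number of samples grows, while individual samples differ from one another; that per-call randomness is what the layer injects.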

The arguments permit separate specification of the surrogate posterior (q(W|x)), prior (p(W)), and divergence for both the kernel and bias distributions.
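The divergence term compares the surrogate posterior q(W|x) with the prior p(W). As a hedged illustration, the NumPy sketch below assumes a diagonal Normal surrogate (in the spirit of the `make_kernel_bias_posterior_mvn_diag` default) and, for simplicity only, a standard Normal prior; the layer's default prior is a spike-and-slab mixture, for which the KL against a Normal has no closed form. The sketch compares the closed-form Gaussian KL with a Monte Carlo estimate built from reparameterized samples, an estimator that still applies when no closed form exists. All function names are hypothetical.

```python
import numpy as np


def kl_diag_normal(q_loc, q_scale, p_loc, p_scale):
    """Closed-form KL(q || p) for independent Normals, summed over all weights."""
    var_ratio = (q_scale / p_scale) ** 2
    mean_term = ((q_loc - p_loc) / p_scale) ** 2
    return 0.5 * np.sum(var_ratio + mean_term - 1.0 - np.log(var_ratio))


def normal_log_prob(w, loc, scale):
    """Elementwise log-density of Normal(loc, scale) at w."""
    return -0.5 * ((w - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def kl_monte_carlo(q_loc, q_scale, p_loc, p_scale, num_samples, rng):
    """Estimate E_q[log q(W) - log p(W)] from reparameterized samples W = loc + scale * eps."""
    eps = rng.standard_normal((num_samples,) + q_loc.shape)
    w = q_loc + q_scale * eps
    log_ratio = normal_log_prob(w, q_loc, q_scale) - normal_log_prob(w, p_loc, p_scale)
    # Sum over all weights for each sample, then average over samples.
    return np.mean(log_ratio.reshape(num_samples, -1).sum(axis=1))


rng = np.random.default_rng(1)
q_loc = rng.normal(scale=0.3, size=(3, 3, 1, 4))  # surrogate posterior means
q_scale = np.full_like(q_loc, 0.5)                 # surrogate posterior std devs
p_loc, p_scale = 0.0, 1.0                          # standard Normal prior (assumption)

kl_exact = kl_diag_normal(q_loc, q_scale, p_loc, p_scale)
kl_mc = kl_monte_carlo(q_loc, q_scale, p_loc, p_scale, num_samples=20000, rng=rng)
```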

Upon being built, this layer adds losses (accessible via the losses property) representing the divergences of kernel and/or bias surrogate posteriors and their respective priors. When doing minibatch stochastic optimization, make sure to scale this loss such that it is applied just once per epoch (e.g. if kl is the sum of losses for each element of the batch, you should pass kl / num_examples_per_epoch to your optimizer).
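A short numerical check of this scaling advice, assuming (as in the example below) that each minibatch loss is the batch-mean negative log-likelihood plus the KL divergence divided by the number of training examples N. Averaged over one epoch of equal-sized batches, these losses recover the full-data objective, with the KL counted once, divided by N. The function names and the synthetic numbers are hypothetical.

```python
import numpy as np


def full_data_objective(nll_per_example, kl):
    """Negative ELBO over the whole dataset: every NLL term plus the KL counted once."""
    return np.sum(nll_per_example) + kl


def minibatch_loss(nll_batch, kl, num_examples_per_epoch):
    """Per-step loss: batch-mean NLL plus the KL scaled by 1 / num_examples_per_epoch."""
    return np.mean(nll_batch) + kl / num_examples_per_epoch


rng = np.random.default_rng(0)
num_examples, batch_size = 60, 12
nll = rng.uniform(0.5, 2.0, size=num_examples)  # synthetic per-example NLLs
kl = 37.0                                        # synthetic total KL divergence

batches = nll.reshape(-1, batch_size)            # 5 equal-sized minibatches = 1 epoch
epoch_average = np.mean([minibatch_loss(b, kl, num_examples) for b in batches])
```

If the KL were instead added unscaled at every step, it would be weighted N times more heavily relative to the likelihood than the ELBO prescribes.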

You can access the kernel and/or bias posterior and prior distributions after the layer is built via the kernel_posterior, kernel_prior, bias_posterior and bias_prior properties.

Examples

We illustrate a Bayesian neural network with variational inference, assuming a dataset of images and length-10 one-hot targets.

import functools
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import tensorflow_datasets as tfds
tfb = tfp.bijectors
tfd = tfp.distributions
tfn = tfp.experimental.nn

# 1  Prepare Dataset

[train_dataset, eval_dataset], datasets_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    with_info=True,
    as_supervised=True,
    shuffle_files=True)
def _preprocess(image, label):
  # image = image < tf.random.uniform(tf.shape(image))   # Randomly binarize.
  image = tf.cast(image, tf.float32) / 255.  # Scale to unit interval.
  lo = 0.001
  image = (1. - 2. * lo) * image + lo  # Rescale to *open* unit interval.
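  # One-hot encode the label as a length-10 target kept inside the *open* simplex
  # (entries sum to 1, none exactly 0), so `SoftmaxCentered` log-probs stay finite.
  label = tf.one_hot(label, depth=10, on_value=1. - 9. * lo, off_value=lo)
  return image, label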
batch_size = 32
train_size = datasets_info.splits['train'].num_examples
train_dataset = tfn.util.tune_dataset(
    train_dataset,
    batch_shape=(batch_size,),
    shuffle_size=int(train_size / 7),
    preprocess_fn=_preprocess)
train_iter = iter(train_dataset)
eval_iter = iter(eval_dataset)
x, y = next(train_iter)
evidence_shape = x.shape[1:]
targets_shape = y.shape[1:]

# 2  Specify Model

n = tf.cast(train_size, tf.float32)

BayesConv2D = functools.partial(
    tfn.ConvolutionVariationalReparameterization,
    rank=2,
    padding='same',
    filter_shape=5,
    # Use `he_uniform` because we'll use the `relu` family.
    kernel_initializer=tf.initializers.he_uniform(),
    penalty_weight=1. / n)

BayesAffine = functools.partial(
    tfn.AffineVariationalReparameterization,
    penalty_weight=1. / n)

scale = tfp.util.TransformedVariable(1., tfb.Softplus())
bnn = tfn.Sequential([
    BayesConv2D(evidence_shape[-1], 32, filter_shape=7, strides=2,
                activation_fn=tf.nn.leaky_relu),           # [b, 14, 14, 32]
    tfn.util.flatten_rightmost(ndims=3),                   # [b, 14 * 14 * 32]
    BayesAffine(14 * 14 * 32, np.prod(targets_shape) - 1),  # [b, 9]
    lambda loc: tfb.SoftmaxCentered()(
        tfd.Independent(tfd.Normal(loc, scale),
                        reinterpreted_batch_ndims=1)),  # [b, 10]
], name='bayesian_neural_network')

print(bnn.summary())

# 3  Train.

def loss_fn():
  x, y = next(train_iter)
  nll = -tf.reduce_mean(bnn(x).log_prob(y), axis=-1)
  kl = bnn.extra_loss  # Already normalized via `penalty_weight` arg.
  loss = nll + kl
  return loss, (nll, kl)
opt = tf.optimizers.Adam()
fit_op = tfn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables)
for _ in range(200):
  loss, (nll, kl), g = fit_op()

This example uses reparameterization gradients to minimize, up to a constant, the Kullback-Leibler divergence from the surrogate posterior to the true posterior; this quantity is also known as the negative Evidence Lower Bound (ELBO). It is the sum of two terms: the expected negative log-likelihood, which we approximate via Monte Carlo, and the KL divergence between the surrogate posterior and the prior, which each layer contributes as a penalty scaled by its penalty_weight argument and which the example reads from bnn.extra_loss.
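Written out for a training set of N examples, the objective above takes the following form. The notation is introduced here for illustration: q_θ is the surrogate posterior with parameters θ, B the batch size, S the number of Monte Carlo samples (the example draws one per call), and the reparameterization line assumes a diagonal Normal surrogate. The 1/N factor on the KL corresponds to penalty_weight=1. / n in the example.

```latex
\begin{aligned}
-\mathrm{ELBO}(\theta)
  &= \sum_{i=1}^{N} \mathbb{E}_{q_\theta(W)}\!\left[-\log p(y_i \mid x_i, W)\right]
     + \mathrm{KL}\!\left(q_\theta(W) \,\|\, p(W)\right) \\
  &= \mathrm{KL}\!\left(q_\theta(W) \,\|\, p(W \mid \mathcal{D})\right) - \log p(\mathcal{D}), \\
\frac{1}{N}\left(-\mathrm{ELBO}(\theta)\right)
  &\approx \frac{1}{B} \sum_{i \in \text{batch}} \frac{1}{S} \sum_{s=1}^{S}
     -\log p\!\left(y_i \mid x_i, W^{(s)}\right)
     + \frac{1}{N}\,\mathrm{KL}\!\left(q_\theta(W) \,\|\, p(W)\right), \\
W^{(s)} &= \mu_\theta + \sigma_\theta \odot \epsilon^{(s)}, \qquad
\epsilon^{(s)} \sim \mathcal{N}(0, I).
\end{aligned}
```

Since -log p(D) does not depend on θ, minimizing the negative ELBO minimizes the KL divergence to the true posterior.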

References

[1]: Diederik Kingma and Max Welling. Auto-Encoding Variational Bayes. In International Conference on Learning Representations, 2014. https://arxiv.org/abs/1312.6114

Args

input_size ... In Keras, this argument is inferred from the rightmost input shape, i.e., tf.shape(inputs)[-1]. This argument specifies the size of the rightmost dimension of inputs and of the second-from-rightmost dimension of kernel. Default value: None.
output_size ... In Keras, this argument is called filters. This argument specifies the rightmost dimension size of both kernel and bias.
filter_shape ... In Keras, this argument is called kernel_size. This argument specifies the sizes of the leftmost rank dimensions of kernel.
rank An integer, the rank of the convolution, e.g. "2" for 2D convolution. This argument implies the number of kernel dimensions, i.e., kernel.shape.rank == rank + 2. In Keras, this argument has the same name and semantics. Default value: 2.
strides An integer or tuple/list of rank integers, specifying the stride length of the convolution. In Keras, this argument has the same name and semantics. Default value: 1.
padding One of "VALID" or "SAME" (case-insensitive). In Keras, this argument has the same name and semantics (except we don't support "CAUSAL"). Default value: 'VALID'.
dilations An integer or tuple/list of rank integers, specifying the dilation rate to use for dilated convolution. Currently, specifying any dilations value != 1 is incompatible with specifying any strides value != 1. In Keras, this argument is called dilation_rate. Default value: 1.
kernel_initializer ... Default value: None (i.e., tfp.experimental.nn.initializers.glorot_uniform()).
bias_initializer ... Default value: None (i.e., tf.initializers.zeros()).
make_posterior_fn ... Default value: tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag.
make_prior_fn ... Default value: tfp.experimental.nn.util.make_kernel_bias_prior_spike_and_slab.
posterior_value_fn ... Default value: tfd.Distribution.sample
unpack_weights_fn Default value: unpack_kernel_and_bias
dtype ... Default value: tf.float32.
activation_fn ... Default value: None.
seed ... Default value: None (i.e., no seed).
validate_args ...
name ... Default value: None (i.e., 'ConvolutionVariationalReparameterization').
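To see how rank, filter_shape, input_size, output_size, strides, padding, and dilations fit together, here is a small pure-Python sketch, assuming TensorFlow's usual output-size conventions for 'SAME' and 'VALID' padding. The helpers `kernel_shape` and `output_spatial_size` are hypothetical, not part of TFP. It reproduces the shape annotations of the first layer in the example above ([b, 14, 14, 32] from a 28x28x1 MNIST image).

```python
import math


def kernel_shape(input_size, output_size, filter_shape, rank=2):
    """Kernel shape: `rank` spatial filter sizes, then [input_size, output_size]."""
    if isinstance(filter_shape, int):
        filter_shape = (filter_shape,) * rank  # an int is broadcast to every spatial dim
    if len(filter_shape) != rank:
        raise ValueError('filter_shape must have `rank` entries')
    return tuple(filter_shape) + (input_size, output_size)


def output_spatial_size(n, filter_size, stride=1, padding='VALID', dilation=1):
    """Output length along one spatial dimension of input length `n`."""
    effective_filter = (filter_size - 1) * dilation + 1  # footprint of a dilated filter
    padding = padding.upper()                             # padding is case-insensitive
    if padding == 'SAME':
        return math.ceil(n / stride)
    if padding == 'VALID':
        return math.ceil((n - effective_filter + 1) / stride)
    raise ValueError("padding must be 'SAME' or 'VALID'")


# First layer of the example: 28x28x1 input, 32 filters of size 7, stride 2, 'same'.
shape = kernel_shape(input_size=1, output_size=32, filter_shape=7, rank=2)
out = output_spatial_size(28, filter_size=7, stride=2, padding='same')
```

The sketch does not enforce the documented restriction that dilations other than 1 cannot be combined with strides other than 1.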

Attributes

activation_fn

also_track

dtype

name Returns the name of this module as passed or determined in the constructor.

name_scope Returns a tf.name_scope instance for this class.
non_trainable_variables Sequence of non-trainable variables owned by this module and its submodules.
posterior

posterior_value

posterior_value_fn

prior

submodules Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

>>> a = tf.Module()
>>> b = tf.Module()
>>> c = tf.Module()
>>> a.b = b
>>> b.c = c
>>> list(a.submodules) == [b, c]
True
>>> list(b.submodules) == [c]
True
>>> list(c.submodules) == []
True

trainable_variables Sequence of trainable variables owned by this module and its submodules.

unpack_weights_fn

validate_args Python bool indicating possibly expensive checks are enabled.
variables Sequence of variables owned by this module and its submodules.

Methods

load

save

summary

with_name_scope

Decorator to automatically enter the module name scope.

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module produces tf.Variables and tf.Tensors whose names include the module name:

>>> mod = MyModule()
>>> mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
>>> mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

Args
method The method to wrap.

Returns
The original method wrapped such that it enters the module's name scope.

__call__

Call self as a function.