Implements the Glow Bijector from Kingma & Dhariwal (2018)[1].
Inherits From: Chain, Composition, Bijector
tfp.bijectors.Glow(
output_shape=(32, 32, 3), num_glow_blocks=3, num_steps_per_block=32,
coupling_bijector_fn=None, exit_bijector_fn=None, grab_after_block=None,
use_actnorm=True, seed=None, validate_args=False, name='glow'
)
Overview: Glow is a chain of bijectors that transforms a rank-1 tensor (vector) into a rank-3 tensor (e.g. an RGB image). Glow does this by chaining together an alternating series of "Blocks," "Squeezes," and "Exits," each of which is itself a special chain of other bijectors. The intended use of Glow is as part of a tfp.distributions.TransformedDistribution, in which the base distribution over the vector space is used to generate samples in the image space. In the paper, an Independent Normal distribution is used as the base distribution.
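The log-density that the resulting TransformedDistribution assigns to an image follows from the standard change-of-variables formula. Writing g for Glow's forward map (latent vector z to image x) and p_Z for the base distribution, the relation is:

```latex
\log p_X(x) = \log p_Z\!\left(g^{-1}(x)\right) + \log\left|\det \frac{\partial g^{-1}(x)}{\partial x}\right|
```

Here the Jacobian is taken with x flattened to a vector of H \cdot W \cdot C entries, and the second term is the inverse log-det-Jacobian documented under inverse_log_det_jacobian.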
A "Block" (implemented as the GlowBlock bijector) performs most of the transformations that allow Glow to produce sophisticated, complex mappings between the image space and the latent space, and therefore to achieve rich image generation performance. A Block is composed of num_steps_per_block steps, each implemented as a Chain containing an ActivationNormalization (ActNorm) bijector (when use_actnorm=True), followed by an (invertible) OneByOneConv bijector, and finally a coupling bijector. The coupling bijector is an instance of a RealNVP [2] bijector, and uses the coupling_bijector_fn function to instantiate the coupling bijector function that is given to the RealNVP. This function returns a bijector that defines the coupling (e.g. Shift(Scale) for affine coupling or Shift for additive coupling).
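The following is a minimal NumPy sketch of a single flow step as described above (ActNorm, then an invertible 1x1 convolution, then affine coupling) acting on one [H, W, C] image. It is not the tfb.Glow implementation: the helper names, the (x + loc) * exp(log_scale) ActNorm parameterization, and the tiny `toy_coupling_net` standing in for the network built by coupling_bijector_fn are all hypothetical. It shows why each stage is cheap to invert and why the log-determinant splits into a per-channel ActNorm term, H * W * log|det W| for the 1x1 convolution, and the summed log-scales of the coupling.

```python
import numpy as np

# Hypothetical sketch of ONE Glow flow step on a single [H, W, C] image:
# ActNorm -> invertible 1x1 convolution -> affine coupling.
# Not the tfb.Glow implementation.


def toy_coupling_net(x_a, num_out):
    """Hypothetical stand-in for the network built by coupling_bijector_fn.

    It sees only the pass-through half x_a and returns (shift, log_scale)
    for the other half, each with num_out channels.
    """
    h = np.tanh(x_a.sum(axis=-1, keepdims=True))  # [H, W, 1]
    shift = np.repeat(h, num_out, axis=-1)
    log_scale = 0.5 * np.repeat(h, num_out, axis=-1)
    return shift, log_scale


def flow_step_forward(x, an_loc, an_log_scale, w):
    """Returns (y, log_det_jacobian) for one image x of shape [H, W, C]."""
    height, width, c = x.shape
    # ActNorm: per-channel shift, then per-channel scale.
    x = (x + an_loc) * np.exp(an_log_scale)
    log_det = height * width * np.sum(an_log_scale)
    # 1x1 convolution: the same C x C matrix w is applied at every pixel.
    x = x @ w.T
    log_det += height * width * np.linalg.slogdet(w)[1]
    # Affine coupling: the first half of the channels passes through unchanged
    # and conditions the shift/scale applied to the second half.
    x_a, x_b = x[..., : c // 2], x[..., c // 2:]
    shift, log_scale = toy_coupling_net(x_a, x_b.shape[-1])
    y_b = x_b * np.exp(log_scale) + shift
    log_det += np.sum(log_scale)
    return np.concatenate([x_a, y_b], axis=-1), log_det


def flow_step_inverse(y, an_loc, an_log_scale, w):
    """Undoes flow_step_forward by inverting each stage in reverse order."""
    c = y.shape[-1]
    y_a, y_b = y[..., : c // 2], y[..., c // 2:]
    shift, log_scale = toy_coupling_net(y_a, y_b.shape[-1])
    x = np.concatenate([y_a, (y_b - shift) * np.exp(-log_scale)], axis=-1)
    x = x @ np.linalg.inv(w).T
    return x * np.exp(-an_log_scale) - an_loc


# Usage: w starts as a random orthogonal matrix (the Glow paper initializes
# the 1x1 convolution with a random rotation), so it is invertible.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 4, 6))
an_loc, an_log_scale = rng.normal(size=6), 0.1 * rng.normal(size=6)
w = np.linalg.qr(rng.normal(size=(6, 6)))[0]
y, log_det = flow_step_forward(x, an_loc, an_log_scale, w)
print(np.allclose(flow_step_inverse(y, an_loc, an_log_scale, w), x))  # True
```

Because the coupling network only ever sees the pass-through half, the inverse can recompute the same shift and scale from the output, which is what keeps a coupling layer invertible no matter how complex its network is.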
A "Squeeze" converts spatial features into channel features. It is implemented using the Expand bijector. The difference in names is due to the fact that the forward function of Glow is meant to ultimately correspond to sampling from a tfp.distributions.TransformedDistribution object, which would use Expand (Squeeze is just Invert(Expand)). The Expand bijector takes a tensor with shape [H, W, C] and returns a tensor with shape [2H, 2W, C / 4], such that each 2x2x1 spatial tile in the output is composed from a single 1x1x4 tile in the input tensor, as depicted in the figure below.
Figure: Forward pass (Expand) / Inverse pass (Squeeze). The four channels of a 1x1x4 input tile, labelled 1 to 4, are laid out as a 2x2x1 output tile with 1 and 2 in the top row and 3 and 4 in the bottom row; Squeeze reverses this.
This is implemented using a chain of Reshape -> Transpose -> Reshape bijectors. Note that on an inverse pass through the bijector, each Squeeze halves the height and width of the image. Therefore the height and width of the input image must each be divisible by 2 at least num_glow_blocks times (i.e. by 2**num_glow_blocks), since the image passes through a Squeeze step that many times.
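A minimal NumPy sketch of the Reshape -> Transpose -> Reshape construction follows. The function names are hypothetical, and the channel ordering (each 2x2 tile read row by row into channels, matching the 1, 2 / 3, 4 layout in the figure) is an assumption; the Expand bijector used inside tfb.Glow may order channels differently.

```python
import numpy as np


def squeeze(x):
    """Hypothetical Squeeze: [H, W, C] -> [H/2, W/2, 4C] via reshape, transpose, reshape."""
    h, w, c = x.shape
    if h % 2 or w % 2:
        raise ValueError("height and width must be even")
    x = x.reshape(h // 2, 2, w // 2, 2, c)  # split each spatial axis into (tile, offset)
    x = x.transpose(0, 2, 1, 3, 4)          # -> [H/2, W/2, 2, 2, C]
    return x.reshape(h // 2, w // 2, 4 * c)  # fold the 2x2 offsets into channels


def expand(y):
    """Hypothetical Expand (inverse of squeeze): [H, W, C] -> [2H, 2W, C/4]."""
    h, w, c4 = y.shape
    c = c4 // 4
    y = y.reshape(h, w, 2, 2, c)            # unfold channels into 2x2 offsets
    y = y.transpose(0, 2, 1, 3, 4)          # -> [H, 2, W, 2, C]
    return y.reshape(2 * h, 2 * w, c)


print(squeeze(np.arange(4).reshape(2, 2, 1)).ravel())  # [0 1 2 3]
```

Each stage is a pure rearrangement of elements, so the log-det-Jacobian of a Squeeze or Expand is zero.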
An "Exit" is simply a junction at which some of the tensor "exits" from the Glow bijector and therefore avoids any further alteration. Each exit is implemented as a Blockwise bijector, where some channels are given to the rest of the Glow model and the rest are given to a bypass implemented using the Identity bijector. The fraction of channels removed at each exit is determined by the grab_after_block arg, which indicates the fraction of remaining channels that join the identity bypass. The fraction is converted to an integer number of channels by multiplying by the remaining number of channels and rounding.
Additionally, at each exit, Glow couples the tensor leaving the main flow to the tensor continuing onward. This makes small-scale features in the image dependent on larger-scale features, since the larger-scale features dictate the mean and scale of the distribution over the smaller-scale features. This coupling is done similarly to the coupling bijector in each step of the flow (i.e. using a RealNVP bijector). However, for the exit bijector the coupling is instantiated using exit_bijector_fn rather than coupling_bijector_fn, allowing for different behaviors between standard coupling and exit coupling. Also note that because the exit uses a coupling bijector, there are two special cases: all channels exiting and no channels exiting.
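The channel bookkeeping at an exit can be sketched in plain Python. `exit_split` and `exit_case` are hypothetical helpers; this sketch takes the integer floor of fraction times channels, as the grab_after_block argument description states, and does not model which particular channels are routed to the bypass. It makes the two special cases concrete: a fraction of 0 sends no channels to the bypass and a fraction of 1 sends all of them, so in either case one side of the coupling is empty.

```python
import math


def exit_split(num_channels, fraction):
    """Hypothetical helper: returns (channels that continue, channels that exit)."""
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must lie in [0, 1]")
    # Integer floor of fraction * remaining channels.
    n_exit = math.floor(fraction * num_channels)
    return num_channels - n_exit, n_exit


def exit_case(num_channels, fraction):
    """Hypothetical helper naming the two special cases (one side left empty)."""
    n_continue, n_exit = exit_split(num_channels, fraction)
    if n_exit == 0:
        return "no channels exit"
    if n_continue == 0:
        return "all channels exit"
    return "split"


print(exit_split(12, 0.5))  # (6, 6)
print(exit_case(12, 1.0))   # all channels exit
```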
The full Glow bijector consists of num_glow_blocks Blocks, each of which contains num_steps_per_block steps. Each step implements a coupling using coupling_bijector_fn. Between blocks, Glow converts between spatial pixels and channels using the Expand bijector, and splits channels out of the bijector using the Exit bijector. The channels that have exited continue onward through Identity bijectors, and those that have not exited are given to the next block. After passing through all Blocks, the tensor is reshaped to a rank-1 tensor with the same number of elements. This is where the distribution will be defined.
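The shape arithmetic behind the divisibility requirement can be sketched as follows. `glow_block_shapes` is a hypothetical helper that assumes, as in the schematic, one squeeze before each Glow block, and that the full tensor (including channels that have already exited) is squeezed each time; it ignores how channels are divided between the flow and the bypasses. Because a squeeze never changes the number of elements, the final shape always holds H * W * C values, which is why it can be reshaped to a rank-1 tensor of that size.

```python
def glow_block_shapes(output_shape, num_glow_blocks):
    """Hypothetical helper: [H, W, C] of the full tensor after each squeeze on
    the inverse pass, assuming one squeeze before each Glow block."""
    h, w, c = output_shape
    factor = 2 ** num_glow_blocks
    if h % factor or w % factor:
        raise ValueError(
            f"H and W must be divisible by 2**num_glow_blocks = {factor}")
    shapes = []
    for _ in range(num_glow_blocks):
        h, w, c = h // 2, w // 2, 4 * c  # halve H and W, quadruple C
        shapes.append((h, w, c))
    return shapes


shapes = glow_block_shapes((32, 32, 3), num_glow_blocks=3)
print(shapes)  # [(16, 16, 12), (8, 8, 48), (4, 4, 192)]
```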
A schematic diagram of Glow is shown below. The forward function of the bijector starts from the bottom and goes upward, while the inverse function starts from the top and proceeds downward.
==============================================================================
                         Glow Schematic Diagram

Input Image          ########################   shape = [H, W, C]
                     \                      /<- Expand Bijector turns spatial
                      \                    /    dimensions into channels.
                 _
                |     XXXXXXXXXXXXXXXXXXXX
                |     XXXXXXXXXXXXXXXXXXXX
                |     XXXXXXXXXXXXXXXXXXXX     A single step of the flow consists
 Glow Block -   |     XXXXXXXXXXXXXXXXXXXX  <- of ActNorm -> 1x1Conv -> Coupling.
                |     XXXXXXXXXXXXXXXXXXXX     There are num_steps_per_block
                |     XXXXXXXXXXXXXXXXXXXX     steps of the flow in each block.
                |_    XXXXXXXXXXXXXXXXXXXX
                       \                  / <-- Expand bijectors follow each
                        \                /      glow block.
                         XXXXXXXX\\\\\\\\   <-- Exit Bijector removes channels
                 _                      _       from additional alteration.
                |        XXXXXXXX  !  |  !
                |        XXXXXXXX  !  |  !
                |        XXXXXXXX  !  |  !      After exiting, channels are passed
 Glow Block -   |        XXXXXXXX  !  |  !  <-- downward using the Blockwise and
                |        XXXXXXXX  !  |  !      Identity bijectors.
                |        XXXXXXXX  !  |  !
                |_       XXXXXXXX  !  |  !
                         \      /           <-- Expand Bijector
                          \    /
                          XXX\\\  |  !      <-- Exit Bijector
                 _
                |         XXX  !  |  |  !
                |         XXX  !  |  |  !
                |         XXX  !  |  |  !
 Glow Block -   |         XXX  !  |  |  !
                |         XXX  !  |  |  !
                |         XXX  !  |  |  !
                |_        XXX  !  |  |  !
                          XX\  !  |  |  !   <-- (Optional) Exit Bijector
                               |  |  |
                               v  v  v
 Output Distribution     ##########           shape = [H * W * C]

                                          _________________________
                                         |         Legend          |
                                         | XX  = Step of flow      |
                                         | X\  = Exit bijector     |
                                         | \/  = Expand bijector   |
                                         | !|! = Identity bijector |
                                         |                         |
                                         | up  = Forward pass      |
                                         | dn  = Inverse pass      |
                                         |_________________________|
==============================================================================
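Since tfb.Glow inherits from Chain, the direction conventions above follow Chain's ordering. The hypothetical `ToyChain` below is a plain-Python illustration of that convention, not TFP code: on a forward pass the last listed bijector is applied first and the log-det-Jacobians of the pieces add up, while the inverse pass runs through the list from the front. It reuses the sigmoid decomposition from the __call__ examples at the end of this page.

```python
import math


class ToyChain:
    """Hypothetical toy with tfb.Chain's ordering convention: forward applies
    the LAST listed bijector first; inverse applies the FIRST listed one first."""

    def __init__(self, bijectors):
        # Each bijector is a (forward, inverse, forward_log_det_jacobian) triple.
        self.bijectors = list(bijectors)

    def forward(self, x):
        log_det = 0.0
        for fwd, _, fldj in reversed(self.bijectors):
            log_det += fldj(x)  # evaluated at this bijector's own input
            x = fwd(x)
        return x, log_det

    def inverse(self, y):
        for _, inv, _ in self.bijectors:
            y = inv(y)
        return y


# Scalar pieces mirroring Chain([Reciprocal, Shift(1), Exp, Scale(-1)]).
reciprocal_b = (lambda x: 1.0 / x, lambda y: 1.0 / y, lambda x: -2.0 * math.log(abs(x)))
shift_b = (lambda x: x + 1.0, lambda y: y - 1.0, lambda x: 0.0)
exp_b = (math.exp, math.log, lambda x: x)
negate_b = (lambda x: -x, lambda y: -y, lambda x: 0.0)  # log|-1| = 0

sigmoid = ToyChain([reciprocal_b, shift_b, exp_b, negate_b])
y, log_det = sigmoid.forward(0.3)
print(math.isclose(y, 1.0 / (1.0 + math.exp(-0.3))))  # True
```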
The default configuration for glow is meant to replicate the architecture in [1] for generating images from CIFAR-10.
Example usage:
from functools import reduce
from operator import mul
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions
data, info = tfds.load('cifar10', with_info=True)
train_data, test_data = data['train'], data['test']
preprocess = lambda x: tf.cast(x['image'], tf.float32)
train_data = train_data.batch(4).map(preprocess)
test_data = test_data.batch(4).map(preprocess)
x = next(iter(train_data))
glow = tfb.Glow(output_shape=info.features['image'].shape,
coupling_bijector_fn=tfb.GlowDefaultNetwork,
exit_bijector_fn=tfb.GlowDefaultExitNetwork)
z_shape = glow.inverse_event_shape(info.features['image'].shape)
pz = tfd.Sample(tfd.Normal(0., 1.), z_shape)
# Calling glow on distribution p(z) creates our glow distribution over images.
px = glow(pz)
# Sample from the distribution to generate new images resembling your dataset.
images = px.sample(4)
# Map images to positions in the distribution
z = glow.inverse(x)
# Get the z's corresponding to each spatial scale. To do this, we have to
# find out how many zs are passed through blockwise at each stage that were
# not passed at the previous stage. This is encoded in the second element of
# each list of blockwise splits. However because the bijector iteratively
# converts spatial pixels to channels, we also need to multiply the size of
# that second element by the number of spatial-to-channel conversions that the
# tensor receives after exiting (including after any alteration).
ztake = [bs[1] * 4**(i+2) for i, bs in enumerate(glow.blockwise_splits)]
total_z_taken = sum(ztake)
split_sizes = [z_shape.as_list()[0]-total_z_taken] + ztake
zsplits = tf.split(z, num_or_size_splits=split_sizes, axis=-1)
References:
[1]: Diederik P. Kingma and Prafulla Dhariwal. Glow: Generative Flow with Invertible 1x1 Convolutions. In Neural Information Processing Systems, 2018. https://arxiv.org/abs/1807.03039
[2]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation using Real NVP. In International Conference on Learning Representations, 2017. https://arxiv.org/abs/1605.08803
Args | |
---|---|
output_shape | A list of integers specifying the event shape of the output of the bijector's forward pass (the image), given as [H, W, C]. Default value: (32, 32, 3). |
num_glow_blocks | An integer specifying how many downsampling levels to include in the model. Both H and W must be divisible by 2**num_glow_blocks, since each level halves them; otherwise the bijector would not be invertible. Default value: 3. |
num_steps_per_block | An integer specifying how many steps of the flow to include at each level of the spatial hierarchy; each step contains a 1x1 convolution and a coupling layer, preceded by ActNorm when use_actnorm=True. Default value: 32 (the value used in the original Glow paper). |
coupling_bijector_fn | A function which takes the argument input_shape and returns a callable neural network (e.g. a keras.Sequential). The network should return one of: a tensor with the same event shape as input_shape (this employs additive coupling); a tensor with the same height and width as input_shape but twice the number of channels (this employs affine coupling); or a bijector which takes a tensor with event shape input_shape and returns a tensor with shape input_shape. |
exit_bijector_fn | Similar to coupling_bijector_fn, exit_bijector_fn is a function which takes the arguments input_shape and output_chan and returns a callable neural network. The network should take a tensor of shape input_shape as input and return one of three options: a tensor with output_chan channels, a tensor with 2 * output_chan channels, or a bijector. Additional details can be found in the documentation for ExitBijector. |
grab_after_block | A tuple of floats specifying what fraction of the remaining channels to remove following each glow block. Glow takes the integer floor of this number multiplied by the remaining number of channels. Default value: None (half of the channels are taken out after each block). |
use_actnorm | A bool deciding whether or not to use ActNorm. Data-dependent initialization is used to initialize this layer. Default value: True. |
seed | A seed to control randomness in the 1x1 convolution initialization. Default value: None (i.e., non-reproducible initialization). |
validate_args | Python bool indicating whether arguments should be checked for correctness. Default value: False. |
name | Python str, name given to ops managed by this object. Default value: 'glow'. |
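The use_actnorm description mentions data-dependent initialization. In the Glow paper, ActNorm's per-channel scale and bias are initialized so that the post-ActNorm activations have zero mean and unit variance per channel on an initial minibatch, and are then trained as ordinary parameters. A minimal NumPy sketch of that initialization, with hypothetical function names and an assumed (x + loc) * exp(log_scale) parameterization, is:

```python
import numpy as np


def actnorm_data_init(batch, eps=1e-6):
    """Hypothetical data-dependent ActNorm initialization.

    batch has shape [N, H, W, C]. Returns per-channel (loc, log_scale) chosen so
    that actnorm_forward(batch, loc, log_scale) has zero mean and unit standard
    deviation in every channel (up to eps).
    """
    mean = batch.mean(axis=(0, 1, 2))
    std = batch.std(axis=(0, 1, 2))
    return -mean, -np.log(std + eps)


def actnorm_forward(x, loc, log_scale):
    """Per-channel affine map; in a real model loc and log_scale are then trained."""
    return (x + loc) * np.exp(log_scale)


rng = np.random.default_rng(0)
batch = rng.normal(loc=3.0, scale=2.0, size=(8, 4, 4, 5))  # stand-in first minibatch
loc, log_scale = actnorm_data_init(batch)
out = actnorm_forward(batch, loc, log_scale)
```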
Attributes | |
---|---|
bijectors | |
blockwise_splits | |
dtype | |
forward_min_event_ndims | Returns the minimal number of dimensions bijector.forward operates on. Multipart bijectors return structured |
graph_parents | Returns this Bijector's graph_parents as a Python list. |
has_static_min_event_ndims | Returns True if the bijector has statically-known min_event_ndims. |
inverse_min_event_ndims | Returns the minimal number of dimensions bijector.inverse operates on. Multipart bijectors return structured |
is_constant_jacobian | Returns true iff the Jacobian matrix is not a function of x. |
name | Returns the string name of this Bijector. |
name_scope | Returns a tf.name_scope instance for this class. |
parameters | Dictionary of parameters used to instantiate this Bijector. |
submodules | Sequence of all sub-modules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on). |
trainable_variables | Sequence of trainable variables owned by this module and its submodules. |
validate_args | Returns True if Tensor arguments will be validated. |
validate_event_size | |
variables | Sequence of variables owned by this module and its submodules. |
Methods
forward
forward(
x, name='forward', **kwargs
)
Returns the forward Bijector evaluation, i.e., Y = g(X).
Args | |
---|---|
x | Tensor (structure). The input to the 'forward' evaluation. |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
Tensor (structure). |
Raises | |
---|---|
TypeError | if self.dtype is specified and x.dtype is not self.dtype. |
NotImplementedError | if _forward is not implemented. |
forward_dtype
forward_dtype(
dtype=UNSPECIFIED, name='forward_dtype', **kwargs
)
Returns the dtype returned by forward for the provided input.
forward_event_ndims
forward_event_ndims(
event_ndims, **kwargs
)
Returns the number of event dimensions produced by forward.
forward_event_shape
forward_event_shape(
input_shape
)
Shape of a single sample from a single batch as a TensorShape.
Same meaning as forward_event_shape_tensor. May be only partially defined.
Args | |
---|---|
input_shape | TensorShape (structure) indicating event-portion shape passed into forward function. |
Returns | |
---|---|
forward_event_shape_tensor | TensorShape (structure) indicating event-portion shape after applying forward. Possibly unknown. |
forward_event_shape_tensor
forward_event_shape_tensor(
input_shape, name='forward_event_shape_tensor'
)
Shape of a single sample from a single batch as an int32 1D Tensor.
Args | |
---|---|
input_shape | Tensor, int32 vector (structure) indicating event-portion shape passed into forward function. |
name | name to give to the op |
Returns | |
---|---|
forward_event_shape_tensor | Tensor, int32 vector (structure) indicating event-portion shape after applying forward. |
forward_log_det_jacobian
forward_log_det_jacobian(
x, event_ndims, name='forward_log_det_jacobian', **kwargs
)
Returns the log of the absolute value of the determinant of the Jacobian of forward, evaluated at x.
Args | |
---|---|
x | Tensor (structure). The input to the 'forward' Jacobian determinant evaluation. |
event_ndims | Number of dimensions in the probabilistic events being transformed. Must be greater than or equal to self.forward_min_event_ndims. The result is summed over the final dimensions to produce a scalar Jacobian determinant for each event, i.e. it has shape rank(x) - event_ndims dimensions. Multipart bijectors require structured event_ndims, such that rank(y[i]) - rank(event_ndims[i]) is the same for all elements i of the structured input. Furthermore, the first event_ndims[i] of each x[i].shape must be the same for all i (broadcasting is not allowed). |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
Tensor (structure), if this bijector is injective. If not injective, this is not implemented. |
Raises | |
---|---|
TypeError | if y's dtype is incompatible with the expected output dtype. |
NotImplementedError | if neither _forward_log_det_jacobian nor {_inverse, _inverse_log_det_jacobian} are implemented, or this is a non-injective bijector. |
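The event_ndims reduction can be illustrated with an elementwise bijector. The hypothetical helper below computes the forward log-det-Jacobian of y = exp(x) in NumPy: each element contributes x, and summing over the last event_ndims axes leaves a result with rank(x) - event_ndims dimensions, as described above. It is a sketch of the convention, not TFP's implementation.

```python
import numpy as np


def exp_forward_log_det_jacobian(x, event_ndims):
    """Hypothetical helper for the elementwise bijector y = exp(x).

    Each element contributes log|d exp(x)/dx| = x; the contributions are summed
    over the last event_ndims axes, leaving rank(x) - event_ndims dimensions.
    """
    x = np.asarray(x, dtype=float)
    if event_ndims == 0:
        return x
    return x.sum(axis=tuple(range(-event_ndims, 0)))


x = np.arange(24.0).reshape(2, 3, 4)
print(exp_forward_log_det_jacobian(x, 1).shape)  # (2, 3)
print(exp_forward_log_det_jacobian(x, 2).shape)  # (2,)
```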
inverse
inverse(
y, name='inverse', **kwargs
)
Returns the inverse Bijector evaluation, i.e., X = g^{-1}(Y).
Args | |
---|---|
y | Tensor (structure). The input to the 'inverse' evaluation. |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
Tensor (structure), if this bijector is injective. If not injective, returns the k-tuple containing the unique k points (x1, ..., xk) such that g(xi) = y. |
Raises | |
---|---|
TypeError | if y's structured dtype is incompatible with the expected output dtype. |
NotImplementedError | if _inverse is not implemented. |
inverse_dtype
inverse_dtype(
dtype=UNSPECIFIED, name='inverse_dtype', **kwargs
)
Returns the dtype returned by inverse for the provided input.
inverse_event_ndims
inverse_event_ndims(
event_ndims, **kwargs
)
Returns the number of event dimensions produced by inverse.
inverse_event_shape
inverse_event_shape(
output_shape
)
Shape of a single sample from a single batch as a TensorShape.
Same meaning as inverse_event_shape_tensor. May be only partially defined.
Args | |
---|---|
output_shape | TensorShape (structure) indicating event-portion shape passed into inverse function. |
Returns | |
---|---|
inverse_event_shape_tensor | TensorShape (structure) indicating event-portion shape after applying inverse. Possibly unknown. |
inverse_event_shape_tensor
inverse_event_shape_tensor(
output_shape, name='inverse_event_shape_tensor'
)
Shape of a single sample from a single batch as an int32 1D Tensor.
Args | |
---|---|
output_shape | Tensor, int32 vector (structure) indicating event-portion shape passed into inverse function. |
name | name to give to the op |
Returns | |
---|---|
inverse_event_shape_tensor | Tensor, int32 vector (structure) indicating event-portion shape after applying inverse. |
inverse_log_det_jacobian
inverse_log_det_jacobian(
y, event_ndims, name='inverse_log_det_jacobian', **kwargs
)
Returns the (log o det o Jacobian o inverse)(y).
Mathematically, returns: log(det(dX/dY))(Y). (Recall that: X = g^{-1}(Y).)
Note that forward_log_det_jacobian is the negative of this function, evaluated at g^{-1}(y).
Args | |
---|---|
y | Tensor (structure). The input to the 'inverse' Jacobian determinant evaluation. |
event_ndims | Number of dimensions in the probabilistic events being transformed. Must be greater than or equal to self.inverse_min_event_ndims. The result is summed over the final dimensions to produce a scalar Jacobian determinant for each event, i.e. it has shape rank(y) - event_ndims dimensions. Multipart bijectors require structured event_ndims, such that rank(y[i]) - rank(event_ndims[i]) is the same for all elements i of the structured input. Furthermore, the first event_ndims[i] of each x[i].shape must be the same for all i (broadcasting is not allowed). |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
ildj | Tensor, if this bijector is injective. If not injective, returns the tuple of local log det Jacobians, log(det(Dg_i^{-1}(y))), where g_i is the restriction of g to the ith partition Di. |
Raises | |
---|---|
TypeError | if x's dtype is incompatible with the expected inverse-dtype. |
NotImplementedError | if _inverse_log_det_jacobian is not implemented. |
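The relation between the two log-det-Jacobians stated above can be checked numerically for a simple scalar bijector. The sketch below assumes g(x) = exp(x), so that g^{-1}(y) = log(y); the function names are hypothetical stand-ins for the corresponding Bijector methods, not TFP calls.

```python
import math

# Hypothetical scalar example: g(x) = exp(x), so g^{-1}(y) = log(y).


def forward_log_det_jacobian(x):
    return x  # log|d/dx exp(x)| = x


def inverse_log_det_jacobian(y):
    return -math.log(y)  # log|d/dy log(y)| = log(1/y)


for y in (0.5, 1.0, 3.0):
    # ildj(y) equals the negated fldj evaluated at g^{-1}(y).
    assert math.isclose(inverse_log_det_jacobian(y),
                        -forward_log_det_jacobian(math.log(y)))
```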
with_name_scope
@classmethod
with_name_scope(method)
Decorator to automatically enter the module name scope.
class MyModule(tf.Module):
@tf.Module.with_name_scope
def __call__(self, x):
if not hasattr(self, 'w'):
self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
return tf.matmul(x, self.w)
Using the above module would produce tf.Variables and tf.Tensors whose names included the module name:
mod = MyModule()
mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
numpy=..., dtype=float32)>
Args | |
---|---|
method | The method to wrap. |
Returns | |
---|---|
The original method wrapped such that it enters the module's name scope. |
__call__
__call__(
value, name=None, **kwargs
)
Applies or composes the Bijector, depending on input type.
This is a convenience function which applies the Bijector instance in three different ways, depending on the input:
- If the input is a tfd.Distribution instance, return tfd.TransformedDistribution(distribution=input, bijector=self).
- If the input is a tfb.Bijector instance, return tfb.Chain([self, input]).
- Otherwise, return self.forward(input).
Args | |
---|---|
value | A tfd.Distribution, tfb.Bijector, or a (structure of) Tensor. |
name | Python str name given to ops created by this function. |
**kwargs | Additional keyword arguments passed into the created tfd.TransformedDistribution, tfb.Bijector, or self.forward. |
Returns | |
---|---|
composition | A tfd.TransformedDistribution if the input was a tfd.Distribution, a tfb.Chain if the input was a tfb.Bijector, or a (structure of) Tensor computed by self.forward. |
Examples
sigmoid = tfb.Reciprocal()(
tfb.Shift(shift=1.)(
tfb.Exp()(
tfb.Scale(scale=-1.))))
# ==> `tfb.Chain([
# tfb.Reciprocal(),
# tfb.Shift(shift=1.),
# tfb.Exp(),
# tfb.Scale(scale=-1.),
# ])` # ie, `tfb.Sigmoid()`
log_normal = tfb.Exp()(tfd.Normal(0, 1))
# ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`
tfb.Exp()([-1., 0., 1.])
# ==> tf.exp([-1., 0., 1.])
__eq__
__eq__(
other
)
Return self==value.