Regularizer that adds a KL divergence penalty to the model loss.
```python
tfp.layers.KLDivergenceRegularizer(
    distribution_b,
    use_exact_kl=False,
    test_points_reduce_axis=(),
    test_points_fn=tf.convert_to_tensor,
    weight=None
)
```
When using the Monte Carlo approximation (i.e., `use_exact_kl=False`), it is presumed that the input distribution's concretization (i.e., `tf.convert_to_tensor(distribution)`) corresponds to a random sample. To override this behavior, set `test_points_fn`.
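For instance, the following hedged sketch (the `prior` and `mc_regularizer` names are illustrative, not from this page) overrides `test_points_fn` so that the Monte Carlo estimate is averaged over an explicit set of samples:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

prior = tfd.MultivariateNormalDiag(loc=tf.zeros(2))
mc_regularizer = tfpl.KLDivergenceRegularizer(
    prior,
    use_exact_kl=False,
    test_points_fn=lambda d: d.sample(25),  # draw 25 explicit test points
    test_points_reduce_axis=0)              # average the estimate over them
```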
Example
```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers
tfk = tf.keras
tfkl = tf.keras.layers

# Create a variational encoder and add a KL Divergence penalty to the
# loss that encourages marginal coherence with a unit-MVN (the "prior").
input_shape = [28, 28, 1]
encoded_size = 2
num_train_samples = 60000  # illustrative training-set size used to scale the penalty

variational_encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    tfkl.Flatten(),
    tfkl.Dense(10, activation='relu'),
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size)),
    tfpl.MultivariateNormalTriL(
        encoded_size,
        lambda s: s.sample(10),
        activity_regularizer=tfpl.KLDivergenceRegularizer(
            tfd.MultivariateNormalDiag(loc=tf.zeros(encoded_size)),
            weight=num_train_samples)),
])
```
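Once built, calling the encoder registers the weighted KL penalty as an activity-regularization loss, which Keras then adds to the training objective. A minimal, hedged usage sketch (the random `images` batch is purely for illustration):

```python
images = tf.random.uniform([32] + input_shape)  # fake batch of 32 "images"
encoding = variational_encoder(images)          # MultivariateNormalTriL output
print(variational_encoder.losses)               # contains the weighted KL penalty
```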
| Args | |
|---|---|
| `distribution_b` | Distribution instance corresponding to `b` as in `KL[a, b]`. (The previous layer's output is presumed to be a `Distribution` instance and is `a`.) |
| `use_exact_kl` | Python `bool` indicating whether the KL divergence should be calculated exactly via `tfp.distributions.kl_divergence` or via Monte Carlo approximation. Default value: `False`. |
| `test_points_reduce_axis` | `int` vector or scalar representing the dimensions over which to `reduce_mean` while calculating the Monte Carlo approximation of the KL divergence. As with all `tf.reduce_*` ops, `None` means reduce over all dimensions; `()` means reduce over none of them. Default value: `()` (i.e., no reduction). |
| `test_points_fn` | Python `callable` taking a `Distribution` instance and returning a `Tensor` used as random test points to approximate the KL divergence. Default value: `tf.convert_to_tensor`. |
| `weight` | Multiplier applied to the calculated KL divergence for each Keras batch member. Default value: `None` (i.e., do not weight each batch member). |
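To make the `use_exact_kl` and `weight` options concrete, here is a short hedged sketch (the `exact_reg` name is illustrative): a closed-form KL is registered between two multivariate normals, so the exact path can be used and scaled per batch member.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

encoded_size = 2
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(encoded_size))

# Exact KL[a, prior], scaled by 1e-3 for each batch member.
exact_reg = tfpl.KLDivergenceRegularizer(prior, use_exact_kl=True, weight=1e-3)
```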
| Attributes | |
|---|---|
| `distribution_b` | |
| `name` | Returns the name of this module as passed or determined in the ctor. |
| `name_scope` | Returns a `tf.name_scope` instance for this class. |
| `non_trainable_variables` | Sequence of non-trainable variables owned by this module and its submodules. |
| `submodules` | Sequence of all sub-modules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on). |
| `test_points_fn` | |
| `test_points_reduce_axis` | |
| `trainable_variables` | Sequence of trainable variables owned by this module and its submodules. |
| `use_exact_kl` | |
| `variables` | Sequence of variables owned by this module and its submodules. |
| `weight` | |
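The construction arguments listed above are exposed as read-only attributes. A quick hedged check (the values in the comments assume the constructor call shown in this sketch):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

reg = tfpl.KLDivergenceRegularizer(
    tfd.MultivariateNormalDiag(loc=tf.zeros(2)), use_exact_kl=True)
print(reg.use_exact_kl)    # True
print(reg.distribution_b)  # the MultivariateNormalDiag prior passed above
```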
Methods
from_config
```python
@classmethod
from_config(
    config
)
```

Creates a regularizer from its config.

This method is the reverse of `get_config`, capable of instantiating the same regularizer from the config dictionary.

This method is used by Keras `model_to_estimator`, saving and loading models to HDF5 formats, Keras model cloning, some visualization utilities, and exporting models to and from JSON.

| Args | |
|---|---|
| `config` | A Python dictionary, typically the output of `get_config`. |

| Returns |
|---|
| A regularizer instance. |
get_config

```python
get_config()
```

Returns the config of the regularizer.

A regularizer config is a Python dictionary (serializable) containing all configuration parameters of the regularizer. The same regularizer can be reinstantiated later (without any saved state) from this configuration.

This method is optional if you are just training and executing models, exporting to and from SavedModels, or using weight checkpoints.

This method is required for Keras `model_to_estimator`, saving and loading models to HDF5 formats, Keras model cloning, some visualization utilities, and exporting models to and from JSON.

| Returns |
|---|
| A Python dictionary. |
with_name_scope
```python
@classmethod
with_name_scope(
    method
)
```

Decorator to automatically enter the module name scope.

```python
class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)
```

Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose names included the module name:

```python
mod = MyModule()
mod(tf.ones([1, 2]))
# <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
mod.w
# <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=..., dtype=float32)>
```

| Args | |
|---|---|
| `method` | The method to wrap. |

| Returns |
|---|
| The original method wrapped such that it enters the module's name scope. |
__call__
```python
__call__(
    distribution_a
)
```
Compute a regularization penalty from an input tensor.
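As a hedged illustration of calling the regularizer directly (the `posterior` and `prior` names are invented for this sketch), the penalty for the `a` distribution in `KL[a, b]` can be computed outside of a Keras model:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

prior = tfd.MultivariateNormalDiag(loc=tf.zeros(2))
posterior = tfd.MultivariateNormalDiag(loc=tf.ones(2))

regularizer = tfpl.KLDivergenceRegularizer(prior, use_exact_kl=True)
penalty = regularizer(posterior)  # exact KL[posterior, prior] as a Tensor
```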