tfp.experimental.util.make_trainable

Constructs a distribution or bijector instance with trainable parameters.

This is a convenience method that instantiates a class using tf.Variables for its underlying trainable parameters. Parameters are randomly initialized, and transformed to enforce any domain constraints. This method assumes that the class exposes a parameter_properties method annotating its trainable parameters, and that the caller provides any additional (non-trainable) arguments required by the class.

Args

cls: Python class that implements cls.parameter_properties(), e.g., a TFP distribution (tfd.Normal) or bijector (tfb.Scale).
initial_parameters: a dictionary containing initial values for some or all of the parameters to cls, OR a Python callable with signature value = parameter_init_fn(parameter_name, shape, dtype, seed, constraining_bijector). If a dictionary is provided, any parameters not specified will be initialized to a random value in their domain. Default value: None (equivalent to {}; all parameters are initialized randomly).
batch_and_event_shape: Optional int Tensor giving the desired shape of samples (for distributions) or inputs (for bijectors); used to determine the shape of the trainable parameters. Default value: ().
parameter_dtype: Optional float dtype for the trainable variables.
seed: PRNG seed; see tfp.random.sanitize_seed for details. Default value: None.
**init_kwargs: Additional keyword arguments passed to cls.__init__() to specify any non-trainable parameters. If a value is passed for an otherwise-trainable parameter (for example, make_trainable(tfd.Normal, scale=1.)), it is taken as a fixed value and no variable is constructed for that parameter.

Returns

trainable_instance: an instance of cls parameterized by trainable variables.

Example

Suppose we want to fit a normal distribution to observed data. We could of course just examine the empirical mean and standard deviation of the data:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

samples = [4.57, 6.37, 5.93, 7.98, 2.03, 3.59, 8.55, 3.45, 5.06, 6.44]
model = tfd.Normal(
  loc=tf.reduce_mean(samples),  # ==> 5.40
  scale=tf.math.reduce_std(samples))  # ==> 1.95

and this would be a very sensible approach. But that's boring, so instead, let's do way more work to get the same result. We'll build a trainable normal distribution, and explicitly optimize to find the maximum-likelihood estimate for the parameters given our data:

model = tfp.experimental.util.make_trainable(tfd.Normal)
losses = tfp.math.minimize(
  lambda: -tf.reduce_sum(model.log_prob(samples)),  # negative log-likelihood of the data
  optimizer=tf.optimizers.Adam(0.1),
  num_steps=200)
print('Fit Normal distribution with mean {} and stddev {}'.format(
  model.mean().numpy(),
  model.stddev().numpy()))
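Both routes land in the same place because the maximum-likelihood estimates for a normal distribution have a closed form. For n observations x_1, ..., x_n, set the derivatives of the log-likelihood to zero:

```latex
\log p(x_{1:n} \mid \mu, \sigma) = -\tfrac{n}{2}\log(2\pi) - n\log\sigma - \frac{1}{2\sigma^2}\sum_{i=1}^{n}(x_i - \mu)^2

\frac{\partial \log p}{\partial \mu} = \frac{1}{\sigma^2}\sum_{i=1}^{n}(x_i - \mu) = 0
  \;\Rightarrow\; \hat{\mu} = \frac{1}{n}\sum_{i=1}^{n} x_i

\frac{\partial \log p}{\partial \sigma} = -\frac{n}{\sigma} + \frac{1}{\sigma^3}\sum_{i=1}^{n}(x_i - \mu)^2 = 0
  \;\Rightarrow\; \hat{\sigma}^2 = \frac{1}{n}\sum_{i=1}^{n}(x_i - \hat{\mu})^2
```

tf.math.reduce_std divides by n rather than n - 1, so it computes exactly the estimate for sigma above. A well-converged optimizer should therefore recover approximately the same 5.40 and 1.95.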

In this trivial case, doing the explicit optimization has few advantages over the first approach in which we simply matched the empirical moments of the data. However, trainable distributions are useful more generally. For example, they can enable maximum-likelihood estimation of distributions when a moment-matching estimator is not available, and they can also serve as surrogate posteriors in variational inference.