
tfp.experimental.vi.build_split_flow_surrogate_posterior

Builds a joint variational posterior by splitting a normalizing flow.

Args

event_shape (Nested) event shape of the surrogate posterior.
trainable_bijector A trainable tfb.Bijector instance that operates on Tensors (not structures), e.g. tfb.MaskedAutoregressiveFlow or tfb.RealNVP. This bijector transforms the base distribution before it is split.
constraining_bijector tfb.Bijector instance, or nested structure of tfb.Bijector instances, that maps (nested) values in R^n to the support of the posterior. (This can be the experimental_default_event_space_bijector of the distribution over the prior latent variables.) Default value: None (i.e., the posterior is over R^n).
base_distribution A tfd.Distribution subclass parameterized by loc and scale. The base distribution for the transformed surrogate has loc=0. and scale=1.. Default value: tfd.Normal.
batch_shape The batch_shape of the output distribution. Default value: ().
dtype The dtype of the surrogate posterior. Default value: tf.float32.
validate_args Python bool. Whether to validate input with asserts. This imposes a runtime cost. If validate_args is False, and the inputs are invalid, correct behavior is not guaranteed. Default value: False.
name Python str name prefixed to ops created by this function. Default value: None (i.e., 'build_split_flow_surrogate_posterior').

Returns

surrogate_distribution Trainable tfd.TransformedDistribution with event shape equal to event_shape.
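Conceptually, the surrogate draws a flat vector from the base distribution, transforms it with the trainable bijector, and only then splits the result to match the nested event_shape (applying any constraining bijectors per component). A minimal NumPy sketch of that split step, using shapes from the Eight Schools example below (this is an illustration, not the TFP implementation):

```python
import numpy as np

# Nested event shape for Eight Schools: two scalars and one 8-vector.
event_shape = {'avg_effect': [], 'log_stddev': [], 'school_effects': [8]}

# Flattened size of each component (a scalar event contributes 1).
sizes = {k: int(np.prod(s, dtype=int)) for k, s in event_shape.items()}
total = sum(sizes.values())  # 1 + 1 + 8 = 10

# Stand-in for a base-distribution sample pushed through the trainable flow.
flat = np.random.normal(size=total)

# Split the flat vector back into the nested structure.
split, offset = {}, 0
for k, n in sizes.items():
    split[k] = flat[offset:offset + n].reshape(event_shape[k])
    offset += n
```

The key point is that the trainable bijector acts on the full flat vector, so it can model dependencies across all latent variables before the split.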

Examples


import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Train a normalizing flow on the Eight Schools model [1].

treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
model = tfd.JointDistributionNamed({
    'avg_effect':
        tfd.Normal(loc=0., scale=10., name='avg_effect'),
    'log_stddev':
        tfd.Normal(loc=5., scale=1., name='log_stddev'),
    'school_effects':
        lambda log_stddev, avg_effect: (
            tfd.Independent(
                tfd.Normal(
                    loc=avg_effect[..., None] * tf.ones(8),
                    scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
                    name='school_effects'),
                reinterpreted_batch_ndims=1)),
    'treatment_effects': lambda school_effects: tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1)
})

# Pin the observed values in the model.
target_model = model.experimental_pin(treatment_effects=treatment_effects)

# Create a Masked Autoregressive Flow bijector.
net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)

# Build and fit the surrogate posterior.
surrogate_posterior = (
    tfp.experimental.vi.build_split_flow_surrogate_posterior(
        event_shape=target_model.event_shape_tensor(),
        trainable_bijector=maf,
        constraining_bijector=(
            target_model.experimental_default_event_space_bijector())))

losses = tfp.vi.fit_surrogate_posterior(
    target_model.unnormalized_log_prob,
    surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)
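fit_surrogate_posterior minimizes the negative ELBO, estimated by Monte Carlo with reparameterized samples from the surrogate. A 1-D NumPy sketch of that loss for a toy Gaussian target (an illustration of the objective, not TFP's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy unnormalized target: log-density of N(0, 1) without the constant.
def target_log_prob(z):
    return -0.5 * z ** 2

# Surrogate q = Normal(mu, sigma); its exact log-density.
def q_log_prob(z, mu, sigma):
    return (-0.5 * ((z - mu) / sigma) ** 2
            - np.log(sigma) - 0.5 * np.log(2.0 * np.pi))

def negative_elbo(mu, sigma, sample_size=10000):
    # Reparameterized samples: z = mu + sigma * eps, eps ~ N(0, 1).
    z = mu + sigma * rng.standard_normal(sample_size)
    return -np.mean(target_log_prob(z) - q_log_prob(z, mu, sigma))
```

A surrogate close to the target yields a lower loss than one centered far away, which is what the optimizer exploits when training the flow's parameters.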

References

[1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and Donald Rubin. Bayesian Data Analysis, Third Edition. Chapman and Hall/CRC, 2013.