
Registers a rule for substituting distributions in ASVI surrogates.

condition: Python callable that takes a Distribution instance and returns a Python bool indicating whether or not to substitute it. May also be a class type such as tfd.Normal, in which case the condition is interpreted as lambda distribution: isinstance(distribution, class).
substitution_fn: Python callable that takes a Distribution instance and returns a new Distribution instance used to define the ASVI surrogate posterior. Note that this substitution does not modify the original model.
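The class-type shorthand for condition can be pictured with a small, library-free sketch. The helper normalize_condition and the stand-in classes below are hypothetical names introduced only for illustration; this is not TFP's internal code, just one plausible way a class could be turned into the isinstance predicate described above.

```python
import inspect


def normalize_condition(condition):
    """Hypothetical helper: wrap a class in an isinstance predicate.

    Callables that are not classes are returned unchanged.
    """
    if inspect.isclass(condition):
        cls = condition
        return lambda distribution: isinstance(distribution, cls)
    return condition


# Stand-in classes (assumptions for this sketch, not TFP types).
class Distribution:
    pass


class Normal(Distribution):
    pass


class Gamma(Distribution):
    pass


# A class-type condition matches instances of that class only.
is_normal = normalize_condition(Normal)

# A plain callable is used as given.
has_loc = normalize_condition(lambda d: hasattr(d, 'loc'))
```

Under these assumptions, is_normal accepts Normal instances and rejects Gamma instances, while has_loc is simply the lambda that was passed in.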


To use a Normal surrogate for all location-scale family distributions, we could register the substitution:
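import tensorflow_probability as tfp
tfd = tfp.distributions

tfp.experimental.vi.register_asvi_substitution_rule(
  condition=lambda distribution: (
    hasattr(distribution, 'loc') and hasattr(distribution, 'scale')),
  substitution_fn=lambda distribution: (
    # Invoking the event space bijector applies any relevant constraints,
    # e.g., that HalfCauchy samples must be `>= loc`.
    distribution.experimental_default_event_space_bijector()(
      tfd.Normal(loc=distribution.loc, scale=distribution.scale))))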

This rule will fire when ASVI encounters a location-scale distribution, and instructs ASVI to build a surrogate 'as if' the model had just used a (possibly constrained) Normal in its place. Note that we could have used a more precise condition, e.g., to limit the substitution to distributions with a specific name, if we had reason to think that a Normal distribution would be a good surrogate for some model variables but not others.
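A more precise condition of the kind mentioned above might look like the following sketch. The factory make_named_location_scale_condition is a hypothetical helper, and the sketch assumes each distribution exposes a name attribute equal to the name it was constructed with; SimpleNamespace objects stand in for real Distribution instances so the example runs without TFP.

```python
from types import SimpleNamespace


def make_named_location_scale_condition(names):
    """Hypothetical factory: match location-scale distributions by name."""
    allowed = frozenset(names)

    def condition(distribution):
        # Require both location-scale attributes and an allowed name.
        return (hasattr(distribution, 'loc')
                and hasattr(distribution, 'scale')
                and getattr(distribution, 'name', None) in allowed)

    return condition


# Only the variable named 'z' should receive the Normal surrogate.
condition = make_named_location_scale_condition(['z'])

# Stand-ins for Distribution instances (assumptions for this sketch).
z = SimpleNamespace(name='z', loc=0.0, scale=1.0)
w = SimpleNamespace(name='w', loc=0.0, scale=1.0)
rate_only = SimpleNamespace(name='z', rate=2.0)
```

The resulting callable could then be passed as the condition argument in place of the broader hasattr check, leaving other location-scale variables in the model untouched.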