Transforms a program into one that draws samples on a named axis.
```python
oryx.core.ppl.plate(
    f: Optional[oryx.core.ppl.LogProbFunction] = None,
    name: Optional[str] = None
)
```
In graphical model parlance, a plate designates independent random variables. The `plate` transformation follows this idea: a `plate`-ed program draws independent samples. Unlike `jax.vmap`-ing a program, which also produces independent samples but with positional batch dimensions, `plate` produces samples with implicit named axes. Named axis support is useful for other JAX transformations like `pmap` and `xmap`.
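To make the positional-versus-named distinction concrete in plain JAX, independent of Oryx, the sketch below batches the same sampler two ways. With `jax.vmap(sample)` the batch dimension exists only positionally; with `jax.vmap(..., axis_name='foo')` code inside the function can refer to the mapped axis by name, here through the collective `jax.lax.pmean`. The function names `sample` and `centered` are illustrative only.

```python
import jax
from jax import random

def sample(key):
  # Plain program: one standard-normal draw per key.
  return random.normal(key)

keys = random.split(random.PRNGKey(0), 3)

# Positional batching: vmap maps over axis 0 of `keys`. Inside `sample`, each
# call sees a single key and has no handle on the batch dimension.
positional = jax.vmap(sample)(keys)

def centered(key):
  x = sample(key)
  # Named-axis collective: pmean averages x over every index of the axis
  # called 'foo', which is only expressible because the axis has a name.
  return x - jax.lax.pmean(x, axis_name='foo')

named = jax.vmap(centered, axis_name='foo')(keys)
```

Both calls use the same keys, so `named` equals `positional` minus its mean; only the named version could compute that mean from inside the mapped function.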
Specifically, a `plate`-ed program creates a different key for each index along the named axis. `log_prob` reduces over the named axis to produce a single value.
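The two behaviours above can be pictured with a hand-written sketch in plain JAX. `toy_plate` and `toy_plate_log_prob` are hypothetical helpers written for illustration, not Oryx's implementation: the first folds the position along the named axis into the key so each index draws with a different key, and the second reduces per-index log-densities over the named axis with `jax.lax.psum`. Summing is an assumption chosen because the joint log-density of independent variables is the sum of their individual log-densities.

```python
import jax
from jax import random
from jax.scipy.stats import norm

AXIS = 'foo'

def toy_plate(f, name):
  """Hypothetical sketch (not Oryx): one distinct key per index along `name`."""
  def wrapped(key):
    # axis_index gives this call's position along the named axis; fold_in
    # mixes it into the shared key, so every position draws with its own key.
    return f(random.fold_in(key, jax.lax.axis_index(name)))
  return wrapped

def toy_plate_log_prob(x, name):
  """Hypothetical sketch (not Oryx): reduce per-index log-densities to one value."""
  # The draws are independent, so the joint log-density is the sum; psum
  # computes that sum across the named axis and hands it to every index.
  return jax.lax.psum(norm.logpdf(x), axis_name=name)

model = toy_plate(random.normal, AXIS)

# The same key is broadcast to all 3 positions (in_axes=None); fold_in is
# what makes the samples differ.
samples = jax.vmap(model, in_axes=None, axis_size=3, axis_name=AXIS)(
    random.PRNGKey(0))

# Every entry of `total` holds the same reduced log-probability.
total = jax.vmap(lambda x: toy_plate_log_prob(x, AXIS), axis_name=AXIS)(samples)
```

Because `psum` returns the same total at every index, any single entry of `total` is the one reduced value.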
Example usage:
```python
from jax import random, vmap
from oryx.core import ppl
from oryx.core.ppl import random_variable

@ppl.plate(name='foo')
def model(key):
  return random_variable(random.normal)(key)

# We can't call model directly because there are implicit named axes present
try:
  model(random.PRNGKey(0))
except NameError:
  print('No named axis present!')

# If we vmap with a named axis, we produce independent samples.
vmap(model, axis_name='foo')(random.split(random.PRNGKey(0), 3))
# ==> [0.58776844, -0.4009751, 0.01193586]
```
Args | |
---|---|
`f` | a `Program` to transform. If `f` is `None`, `plate` returns a decorator.
`name` | a `str` name for the plate, which can be used as a named axis in JAX functions and transformations.
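The two calling conventions for `f` (a decorator when `f` is `None`, a transformed program otherwise) follow a common Python idiom built on `functools.partial`. The sketch below shows that idiom with a hypothetical `named_transform` standing in for `plate`; it samples nothing and is not Oryx code. Assuming `plate` follows the same convention, `ppl.plate(model, name='foo')` and decorating `model` with `@ppl.plate(name='foo')` would produce equivalent programs.

```python
import functools
from typing import Callable, Optional

def named_transform(f: Optional[Callable] = None, name: Optional[str] = None):
  """Hypothetical transform mimicking plate's optional-`f` convention."""
  if f is None:
    # Called as @named_transform(name=...): return a decorator that already
    # remembers `name` and only waits for the function.
    return functools.partial(named_transform, name=name)

  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    # A real transform would do its work here; this one just tags the result.
    return (name, f(*args, **kwargs))
  return wrapped

# Decorator form: f is None on the first call.
@named_transform(name='foo')
def double(x):
  return 2 * x

# Direct form: f is passed explicitly.
def triple(x):
  return 3 * x

tripled = named_transform(triple, name='bar')
```

`functools.wraps` keeps the wrapped function's name and docstring, so both forms yield a callable that still looks like the original program.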
Returns | |
---|---|
A decorator if `f` is `None`, or a transformed program if `f` is provided. The transformed program produces independent samples across a named axis with name `name`. |