Conditions a probabilistic program on random variables.
oryx.core.ppl.conditional(
f: oryx.core.ppl.LogProbFunction,
names: Union[List[str], str]
) -> oryx.core.ppl.LogProbFunction
conditional is a probabilistic program transformation that converts latent
random variables into conditional inputs to the program. The random variables
that are moved to the input are specified via a list of names that correspond
to tagged random samples from the program. The final positional arguments of the
output program supply the values of these variables, in the same order as the
names passed into conditional.
Random variables that are conditioned are no longer random variables. This
means that if a variable x is conditioned on, it will no longer appear in
the joint_sample of a program and its log_prob will no longer be computed
as part of a program's log_prob.
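The effect on joint sampling and log-density evaluation can be shown with another toy. Again, this is an assumption-laden sketch rather than Oryx code. A hypothetical handler `run_toy_model` runs a model whose random draws all go through a callback `rv`. A `conditioned` dict plays the role of the values that conditional would take as extra inputs. Names found in that dict are returned as given, and they are skipped when the handler builds the joint sample and sums standard-normal log densities.

```python
import math
import random
from typing import Callable, Dict, Tuple


def normal_log_pdf(v: float) -> float:
    """Log density of a standard normal distribution at v."""
    return -0.5 * v * v - 0.5 * math.log(2.0 * math.pi)


def run_toy_model(model: Callable[[Callable[[str], float]], float],
                  seed: int,
                  conditioned: Dict[str, float]) -> Tuple[Dict[str, float], float]:
    """Hypothetical handler: runs `model`, returning the joint sample and the
    joint log prob over the variables that were NOT conditioned on."""
    sample: Dict[str, float] = {}
    log_prob = 0.0

    def rv(name: str) -> float:
        nonlocal log_prob
        if name in conditioned:
            # A conditioned variable is just an input now: it is neither
            # recorded in the joint sample nor scored in the log prob.
            return conditioned[name]
        value = random.Random(f"{seed}/{name}").gauss(0.0, 1.0)
        sample[name] = value
        log_prob += normal_log_pdf(value)
        return value

    model(rv)
    return sample, log_prob


def model(rv: Callable[[str], float]) -> float:
    z = rv('z')
    return z + rv('x')
```

Running this toy with `{'z': 1.0}` yields a joint sample that contains only `x`, and a log prob equal to the log density of `x` alone.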
Example:
from jax import random
from oryx.core.ppl import conditional, random_variable

def model(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)

conditional(model, ['z'])(random.PRNGKey(0), 0.)  # => -1.25153887
conditional(model, ['z'])(random.PRNGKey(0), 1.)  # => -0.25153887
conditional(model, ['z', 'x'])(random.PRNGKey(0), 1., 2.)  # => 3.
Args:
  f: A probabilistic program.
  names: A string or list of strings corresponding to random variable names
    in f.

Returns:
  A probabilistic program with additional conditional inputs.