Rewrites a JAX function according to a rewrite rule.
oryx.experimental.matching.jax_rewrite.rewrite(
    f: oryx.core.ppl.LogProbFunction,
    rule: rules.Rule
) -> oryx.core.ppl.LogProbFunction
Args | |
---|---|
f | A function to be transformed. |
rule | A function that transforms a rules.Expression into another. |
Returns | |
---|---|
A function that, when called with the original arguments to f, executes the body of f rewritten according to the provided rule. |
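
The following is a conceptual sketch of the trace, rewrite, then evaluate pattern described above. It is not Oryx's implementation and does not call the Oryx or JAX APIs. Every name in it (`Var`, `Lit`, `Op`, `Tracer`, `toy_rewrite`, `double_to_add`, `add_to_mul`) is hypothetical and exists only for illustration. It assumes a rule is a plain function from one expression to another, and it supports only `+` and `*`.

```python
from dataclasses import dataclass
from typing import Callable, Tuple, Union


@dataclass(frozen=True)
class Var:
    index: int  # position of the argument in f's signature


@dataclass(frozen=True)
class Lit:
    value: float  # a constant captured while tracing


@dataclass(frozen=True)
class Op:
    name: str  # "add" or "mul" in this toy
    args: Tuple["Expr", ...]


Expr = Union[Var, Lit, Op]


class Tracer:
    """Stands in for an argument and records arithmetic as an expression."""

    def __init__(self, expr: Expr):
        self.expr = expr

    def __add__(self, other):
        return Tracer(Op("add", (self.expr, lift(other))))

    def __radd__(self, other):
        return Tracer(Op("add", (lift(other), self.expr)))

    def __mul__(self, other):
        return Tracer(Op("mul", (self.expr, lift(other))))

    def __rmul__(self, other):
        return Tracer(Op("mul", (lift(other), self.expr)))


def lift(x) -> Expr:
    # Tracers carry an expression already; anything else becomes a literal.
    return x.expr if isinstance(x, Tracer) else Lit(x)


def evaluate(expr: Expr, args):
    # Interpret an expression against concrete argument values.
    if isinstance(expr, Var):
        return args[expr.index]
    if isinstance(expr, Lit):
        return expr.value
    left, right = (evaluate(a, args) for a in expr.args)
    return left + right if expr.name == "add" else left * right


def toy_rewrite(f: Callable, rule: Callable[[Expr], Expr]) -> Callable:
    # Hypothetical analogue of rewrite(f, rule): trace f into an expression,
    # apply the rule once, then evaluate the result on the original arguments.
    def wrapped(*args):
        tracers = [Tracer(Var(i)) for i in range(len(args))]
        expr = lift(f(*tracers))
        return evaluate(rule(expr), args)

    return wrapped


def double_to_add(expr: Expr) -> Expr:
    # Example rule: rewrite `e * 2` as `e + e`, bottom-up through the tree.
    if not isinstance(expr, Op):
        return expr
    args = tuple(double_to_add(a) for a in expr.args)
    if expr.name == "mul" and args[1] == Lit(2):
        return Op("add", (args[0], args[0]))
    return Op(expr.name, args)


def add_to_mul(expr: Expr) -> Expr:
    # Example rule that changes semantics: a top-level add becomes a mul.
    if isinstance(expr, Op) and expr.name == "add":
        return Op("mul", expr.args)
    return expr


def f(x):
    return x * 2 + 1


rewritten = toy_rewrite(f, double_to_add)
```

Here `rewritten` takes the same arguments as `f` but evaluates the rewritten expression rather than the original body. With `double_to_add` the result is numerically unchanged; with `add_to_mul` the returned function computes a product where `f` computed a sum, which shows that it is the rewritten body that actually runs.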