Returns a function transformation that applies a provided set of handlers.
oryx.core.ppl.make_effect_handler(
    handlers: Dict[jax_core.Primitive, oryx.core.ppl.LogProbFunction]
) -> Callable[[Callable[..., Any]], Callable[..., Any]]
Args:
  handlers: A dict that maps JAX primitives to callback functions that take
    in (state, *args) and return (output, new_state). When the transformed
    function runs, the execution of each primitive in handlers is delegated
    to its callback function instead of the primitive's default execution
    rule.
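A handler is an ordinary function that obeys the (state, *args) -> (output, new_state) contract. The pure-Python sketch below imports neither Oryx nor JAX. It uses floats in place of JAX values, and `counting_exp_handler` and `recording_identity_handler` are hypothetical handlers for an exponential primitive. The first keeps the usual result and counts its invocations in the state. The second overrides the default rule and records its input.

```python
import math
from typing import Tuple


# Hypothetical handler for an exponential primitive. The state is an int that
# counts interceptions; the output is the primitive's usual result.
def counting_exp_handler(state: int, x: float) -> Tuple[float, int]:
    return math.exp(x), state + 1


# Hypothetical handler that overrides the default rule: it returns its input
# unchanged instead of exponentiating, and appends that input to a tuple log.
def recording_identity_handler(
    state: Tuple[float, ...], x: float
) -> Tuple[float, Tuple[float, ...]]:
    return x, state + (x,)


# Thread the state by hand through two successive intercepted calls.
state = 0
y, state = counting_exp_handler(state, 0.0)  # y is exp(0.0) == 1.0
z, state = counting_exp_handler(state, y)    # z is exp(1.0); state is now 2

out, log = recording_identity_handler((), 2.5)  # out is 2.5; log is (2.5,)
```

Calling the handlers by hand like this shows how the state is threaded from one intercepted primitive to the next. The transformed function does this threading automatically.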
Returns:
  A function transformation that applies the handlers to an input function.
  The transformation takes in an input function and returns a transformed
  function that takes an additional initial state argument and returns an
  additional output state.