Encapsulates a staged function.
oryx.core.state.function.FunctionModule(
variables: oryx.experimental.matching.jax_rewrite.Bindings,
jaxpr: jax_core.ClosedJaxpr,
in_tree: Any,
out_tree: Any,
*,
name: Optional[str] = None
)
Methods
call
call(
*args, **kwargs
)
call_and_update
call_and_update(
*args, **kwargs
)
flatten
flatten()
replace
replace(
*, variables=None
)
unflatten
@classmethod
unflatten(data, variable_vals)
update
update(
*args, **kwargs
)
variables
variables()
__call__
__call__(
*args, **kwargs
) -> Any
Emulates a regular function call.
A Module's dunder call (__call__) ensures that state is updated after the function call by calling assign on the updated state before returning the output of the function.
Args | |
---|---|
*args | The arguments to the module. |
**kwargs | The keyword arguments to the module. |
Returns | |
---|---|
The output of the module. |
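The call-then-assign behavior of __call__ can be illustrated with a plain-Python toy. This is a hypothetical sketch, not Oryx's implementation: ToyModule, its step field, the state dictionary, and the ASSIGNED registry standing in for assign are all assumptions made for illustration. It assumes call_and_update returns the output together with an updated module, and that __call__ assigns the updated module's variables before returning the output.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Tuple

# Hypothetical stand-in for Oryx's state sink: maps a module name to its latest state.
ASSIGNED: Dict[str, Dict[str, Any]] = {}


def assign(variables: Dict[str, Any], name: str) -> None:
    """Hypothetical stand-in for assign: records the latest state under `name`."""
    ASSIGNED[name] = dict(variables)


@dataclass(frozen=True)
class ToyModule:
    """Hypothetical immutable module; not the Oryx FunctionModule."""
    name: str
    state: Dict[str, Any]
    # Assumed signature: step(state, *args, **kwargs) -> (output, new_state).
    step: Callable[..., Tuple[Any, Dict[str, Any]]]

    def variables(self) -> Dict[str, Any]:
        return dict(self.state)

    def call_and_update(self, *args, **kwargs) -> Tuple[Any, "ToyModule"]:
        # Pure: returns the output plus a new module; self is left unchanged.
        out, new_state = self.step(self.state, *args, **kwargs)
        return out, replace(self, state=new_state)

    def __call__(self, *args, **kwargs) -> Any:
        # Emulates a regular call: run, assign the updated state, return the output.
        out, new_module = self.call_and_update(*args, **kwargs)
        assign(new_module.variables(), name=self.name)
        return out


def counter_step(state: Dict[str, Any], x: int) -> Tuple[int, Dict[str, Any]]:
    # Assumed example step: add x to a running total and return the new total.
    total = state["total"] + x
    return total, {"total": total}


m = ToyModule(name="counter", state={"total": 0}, step=counter_step)
print(m(5))      # 5
print(ASSIGNED)  # {'counter': {'total': 5}}
print(m.state)   # {'total': 0}
```

Because call_and_update is pure in this sketch, the original module m keeps its old state; only the assign side effect records the new one.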