Converts a JAX primitive into a Primitive expression.
oryx.experimental.matching.jax_rewrite.primitive_to_expression(
    prim: jax_core.Primitive
) -> Callable[[Tuple[Any], oryx.experimental.matching.jax_rewrite.Params], oryx.experimental.matching.jax_rewrite.Primitive]
Args | |
---|---|
prim | A jax.core.Primitive to be converted into an expression. |
Returns | |
---|---|
A function that takes a tuple of operands and a Params object and returns a Primitive expression for prim. |
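The return value is effectively a curried constructor. The primitive is fixed first, and the operands and parameters are supplied later. The sketch below illustrates only that shape, in plain Python. It does not import Oryx or JAX and is not the library's implementation. FakePrimitive, PrimitiveExpression, and primitive_to_expression_sketch are hypothetical stand-ins for jax.core.Primitive, the Primitive expression type, and this function.

```python
from dataclasses import dataclass
from typing import Any, Callable, Mapping, Tuple


@dataclass(frozen=True)
class FakePrimitive:
    """Hypothetical stand-in for jax.core.Primitive; only the name is modeled."""
    name: str


@dataclass(frozen=True)
class PrimitiveExpression:
    """Hypothetical stand-in for the Primitive expression type:
    a primitive bundled with its operands and parameters."""
    prim: FakePrimitive
    operands: Tuple[Any, ...]
    params: Mapping[str, Any]


def primitive_to_expression_sketch(
    prim: FakePrimitive,
) -> Callable[[Tuple[Any, ...], Mapping[str, Any]], PrimitiveExpression]:
    """Hypothetical analogue: fixes `prim` now and returns a function that
    builds an expression once operands and parameters are provided."""

    def make_expression(
        operands: Tuple[Any, ...], params: Mapping[str, Any]
    ) -> PrimitiveExpression:
        # Copy inputs so later mutation of the caller's objects
        # does not leak into the expression.
        return PrimitiveExpression(prim, tuple(operands), dict(params))

    return make_expression


# Usage: build a constructor for an "add" primitive, then apply it.
add = primitive_to_expression_sketch(FakePrimitive("add"))
expr = add((1, 2), {})
print(expr.prim.name, expr.operands)  # add (1, 2)
```

Separating the two steps lets one constructor per primitive be created once and then reused for every occurrence of that primitive with different operands and parameters.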