Encapsulates JAX's pjit primitive.
Inherits From: CallPrimitive, JaxExpression, Expression, Pattern
oryx.experimental.matching.jax_rewrite.PjitPrimitive(
    primitive: jax_core.Primitive,
    operands: Sequence[Any],
    expression: Any,
    params: oryx.experimental.matching.jax_rewrite.Params,
    variable_names: Sequence[str]
)
| Attributes | |
|---|---|
| dtype | |
| shape | |
| primitive | Dataclass field |
| operands | Dataclass field |
| expression | Dataclass field |
| params | Dataclass field |
| variable_names | Dataclass field |
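Because primitive, operands, expression, params and variable_names are dataclass fields, they can be inspected with the standard dataclasses module. The sketch below uses a hypothetical stand-in, ToyPjitPrimitive, that only mirrors the constructor's field names; it is not Oryx's class, and the string "pjit" and the tuple-based expression are placeholders for a real jax_core.Primitive and a real inner expression.

```python
import dataclasses
from typing import Any, Dict, Sequence

# Hypothetical stand-in for PjitPrimitive: it mirrors the field names of the
# documented constructor but does not wrap a real JAX primitive.
@dataclasses.dataclass(frozen=True)
class ToyPjitPrimitive:
  primitive: str                 # placeholder for a jax_core.Primitive
  operands: Sequence[Any]        # inputs passed to the call
  expression: Any                # the wrapped (inner) expression
  params: Dict[str, Any]         # primitive parameters
  variable_names: Sequence[str]  # names the inner expression uses for its inputs

node = ToyPjitPrimitive(
    primitive="pjit",
    operands=(1.0, 2.0),
    expression=("add", "a", "b"),
    params={"name": "f"},
    variable_names=("a", "b"),
)

# The attributes marked "Dataclass field" are exactly the dataclass fields.
field_names = [f.name for f in dataclasses.fields(node)]
print(field_names)
# ['primitive', 'operands', 'expression', 'params', 'variable_names']
```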
Methods
evaluate
evaluate(
env: oryx.experimental.matching.jax_rewrite.Bindings
) -> Any
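evaluate takes an environment of Bindings and returns a value. The following is a minimal, self-contained toy model of how a call-style node might evaluate: the operands are evaluated in the outer environment, and the inner expression is then evaluated with its variable_names bound to those values. The classes Var, Add and ToyCall, and the assumption that Bindings is a plain dict from names to values, are illustrative only and are not Oryx's implementation, which dispatches to the underlying JAX primitive.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple

Bindings = Dict[str, Any]  # assumption: a mapping from variable names to values

@dataclass(frozen=True)
class Var:
  name: str
  def evaluate(self, env: Bindings) -> Any:
    return env[self.name]  # look the variable up in the environment

@dataclass(frozen=True)
class Add:
  x: Any
  y: Any
  def evaluate(self, env: Bindings) -> Any:
    return self.x.evaluate(env) + self.y.evaluate(env)

@dataclass(frozen=True)
class ToyCall:  # hypothetical call node with PjitPrimitive-like fields
  operands: Tuple[Any, ...]
  expression: Any
  variable_names: Tuple[str, ...]
  def evaluate(self, env: Bindings) -> Any:
    # Evaluate the operands in the outer environment first...
    args = [op.evaluate(env) for op in self.operands]
    # ...then evaluate the body with its own variable names bound to them.
    inner_env = dict(zip(self.variable_names, args))
    return self.expression.evaluate(inner_env)

call = ToyCall(operands=(Var("x"), Var("y")),
               expression=Add(Var("a"), Var("b")),
               variable_names=("a", "b"))
print(call.evaluate({"x": 3, "y": 4}))  # 7
```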
match
match(
    expr: Expr,
    bindings: oryx.experimental.matching.jax_rewrite.Bindings,
    succeed: oryx.experimental.matching.jax_rewrite.Continuation
) -> oryx.experimental.matching.jax_rewrite.Success
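The match signature, which takes bindings and a succeed continuation and returns a Success, suggests continuation-passing matching. The sketch below shows that style in plain Python under stated assumptions: Bindings is a dict, Success is an iterator of bindings, and Continuation is a function from bindings to Success. Var, Node and match_pattern are hypothetical helpers, not part of the Oryx API; the point is how each sub-match passes its extended bindings on to the next step.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterator, Tuple

Bindings = Dict[str, Any]                   # assumption
Success = Iterator[Bindings]                # assumption
Continuation = Callable[[Bindings], Success]  # assumption

@dataclass(frozen=True)
class Var:
  name: str
  def match(self, expr: Any, bindings: Bindings, succeed: Continuation) -> Success:
    if self.name in bindings:
      # Already bound: succeed only if the existing binding agrees.
      if bindings[self.name] == expr:
        yield from succeed(bindings)
    else:
      yield from succeed({**bindings, self.name: expr})

@dataclass(frozen=True)
class Node:
  op: str
  args: Tuple[Any, ...]
  def match(self, expr: Any, bindings: Bindings, succeed: Continuation) -> Success:
    if (not isinstance(expr, Node) or expr.op != self.op
        or len(expr.args) != len(self.args)):
      return  # no match: yield nothing
    # Thread bindings through the children from left to right.
    def match_from(i: int, b: Bindings) -> Success:
      if i == len(self.args):
        yield from succeed(b)
      else:
        yield from match_pattern(self.args[i], expr.args[i], b,
                                 lambda b2: match_from(i + 1, b2))
    yield from match_from(0, bindings)

def match_pattern(pattern: Any, expr: Any, bindings: Bindings,
                  succeed: Continuation) -> Success:
  if hasattr(pattern, "match"):
    yield from pattern.match(expr, bindings, succeed)
  elif pattern == expr:  # literals must be equal
    yield from succeed(bindings)

pattern = Node("add", (Var("x"), Var("x")))
identity = lambda b: iter([b])  # final continuation: report the bindings
print(list(pattern.match(Node("add", (1, 1)), {}, identity)))  # [{'x': 1}]
print(list(pattern.match(Node("add", (1, 2)), {}, identity)))  # []
```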
tree_children
tree_children() -> Iterator[Expr]
tree_map
tree_map(fn) -> 'PjitPrimitive'
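tree_children and tree_map are the two hooks a generic traversal needs: one enumerates a node's children, the other rebuilds the node with a function applied to each child. The sketch below, with hypothetical Leaf and ToyNode classes and a hypothetical rewrite_bottom_up helper (none of them Oryx APIs), shows how a bottom-up rewrite can be written using only these two methods.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Iterator, Tuple

@dataclass(frozen=True)
class Leaf:
  value: Any
  def tree_children(self) -> Iterator[Any]:
    return iter(())  # leaves have no children
  def tree_map(self, fn: Callable[[Any], Any]) -> "Leaf":
    return self

@dataclass(frozen=True)
class ToyNode:
  name: str
  operands: Tuple[Any, ...]
  def tree_children(self) -> Iterator[Any]:
    yield from self.operands
  def tree_map(self, fn: Callable[[Any], Any]) -> "ToyNode":
    # Rebuild the node with fn applied to each child; other fields are kept.
    return replace(self, operands=tuple(fn(op) for op in self.operands))

def rewrite_bottom_up(expr: Any, rule: Callable[[Any], Any]) -> Any:
  # Rewrite the children first, then apply the rule to the rebuilt node.
  expr = expr.tree_map(lambda child: rewrite_bottom_up(child, rule))
  return rule(expr)

def double_leaves(e: Any) -> Any:
  return Leaf(e.value * 2) if isinstance(e, Leaf) else e

tree = ToyNode("f", (Leaf(1), ToyNode("g", (Leaf(2), Leaf(3)))))
result = rewrite_bottom_up(tree, double_leaves)
print(result)
print(len(list(tree.tree_children())))  # 2
```

Returning a new node from tree_map, rather than mutating in place, keeps expressions immutable, which fits the return annotation 'PjitPrimitive' in the signature above.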
__eq__
__eq__(other)