oryx.experimental.matching.jax_rewrite.PjitPrimitive

Encapsulates JAX's pjit primitive.

Inherits From: CallPrimitive, JaxExpression, Expression, Pattern
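To show what this class wraps, here is a minimal sketch in plain JAX. It does not use Oryx. It traces a function that calls a `jax.jit`-decorated helper and locates the resulting call-primitive equation in the jaxpr. The helper names `inner` and `outer` are hypothetical. The sketch also assumes the primitive is named `"pjit"`, or `"jit"` in some newer JAX releases, so it accepts either name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def inner(x):
    # Hypothetical helper: its body becomes the nested jaxpr of the call equation.
    return jnp.sin(x) * 2.0


def outer(x):
    # Calling a jitted function while tracing emits a single call-primitive equation.
    return inner(x) + 1.0


closed = jax.make_jaxpr(outer)(jnp.ones(3))
print(closed)

# Assumption: the primitive is named "pjit" (or "jit" in some newer JAX versions).
call_eqns = [e for e in closed.jaxpr.eqns if e.primitive.name in ("pjit", "jit")]
eqn = call_eqns[0]
print(eqn.primitive.name)   # the call primitive itself
print(eqn.invars)           # the operands passed to it
print(sorted(eqn.params))   # its parameters, including the nested "jaxpr"
```

An Oryx `PjitPrimitive` is presumably built from an equation like `eqn`. The exact conversion is not shown on this page.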

Attributes

dtype

shape

primitive: Dataclass field.
operands: Dataclass field.
expression: Dataclass field.
params: Dataclass field.
variable_names: Dataclass field.
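The fields above appear to mirror the parts of a pjit equation in a jaxpr. That mapping is an assumption, because this page does not document it. The sketch below uses plain JAX to pull out the corresponding pieces: the primitive, its operands, its parameters, and the called computation stored under the `"jaxpr"` parameter. The helper `inner` is hypothetical, and the sketch accepts both `"pjit"` and `"jit"` as the primitive name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def inner(x):
    # Hypothetical helper whose body is the called computation.
    return jnp.sin(x) * 2.0


closed = jax.make_jaxpr(lambda x: inner(x) + 1.0)(jnp.ones(3))
# Assumption: the call primitive is named "pjit" (or "jit" in some newer JAX versions).
eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name in ("pjit", "jit"))

# The "jaxpr" parameter holds the called computation as a ClosedJaxpr.
body = eqn.params["jaxpr"]
inner_prims = [e.primitive.name for e in body.jaxpr.eqns]
print(inner_prims)

# Variables bound by the nested jaxpr: its formal inputs and outputs.
print(body.jaxpr.invars, body.jaxpr.outvars)
```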

Methods

evaluate

match

tree_children

tree_map

__eq__