oryx.experimental.matching.jax_rewrite.Part

Used to select the outputs of JAX primitives with multiple outputs.

Inherits From: JaxExpression, Expression, Pattern

When a JAX primitive has multiple_results = True, calling it returns several outputs. To represent these outputs in an expression tree, the primitive's output is wrapped in one Part per output, each carrying the index of the output it selects. Part is used primarily with CallPrimitives, which always have multiple outputs.
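As a rough illustration, the following pure-Python sketch mimics the idea without depending on Oryx or JAX. Var, MultiOutput and PartSketch are hypothetical stand-ins, not the library's classes: a multiple-output node evaluates to a tuple, and each PartSketch selects one element of that tuple by index.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass(frozen=True)
class Var:
    """Hypothetical leaf expression: looks its value up in an environment."""
    name: str

    def evaluate(self, env: Dict[str, Any]) -> Any:
        return env[self.name]


@dataclass(frozen=True)
class MultiOutput:
    """Hypothetical stand-in for a primitive with multiple_results = True.

    Evaluating it returns a tuple of outputs rather than a single value.
    """
    fn: Callable[..., Tuple[Any, ...]]
    operands: Tuple[Any, ...]

    def evaluate(self, env: Dict[str, Any]) -> Tuple[Any, ...]:
        args = [op.evaluate(env) for op in self.operands]
        return tuple(self.fn(*args))


@dataclass(frozen=True)
class PartSketch:
    """Hypothetical analogue of Part: selects output `index` of `operand`."""
    operand: Any
    index: int

    def evaluate(self, env: Dict[str, Any]) -> Any:
        # The operand evaluates to something indexable, i.e. operand[i].
        return self.operand.evaluate(env)[self.index]


# divmod returns two outputs, so each one is wrapped in its own PartSketch.
qr = MultiOutput(divmod, (Var("a"), Var("b")))
quotient = PartSketch(qr, 0)
remainder = PartSketch(qr, 1)

env = {"a": 17, "b": 5}
print(quotient.evaluate(env), remainder.evaluate(env))  # 3 2
```

Both parts share the same MultiOutput node, so the tree records that they come from one call rather than two.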

Attributes

operand An expression that can be indexed into with an integer, i.e. operand[i].
index The index used when accessing the operand.
dtype

shape

Methods

evaluate

match

tree_children

tree_map

__eq__
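The methods above are listed without descriptions. The sketch below shows one plausible shape for tree traversal on a Part-like node, assuming that tree_children yields the node's child expressions and that tree_map rebuilds the node after applying a function to them. Leaf and PartSketch are hypothetical, and treating only the operand as a child (with the index kept as static data) is an assumption; Oryx's own traversal may handle the index differently.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterator


@dataclass(frozen=True)
class Leaf:
    """Hypothetical leaf holding a tuple of outputs."""
    value: Any

    def evaluate(self) -> Any:
        return self.value


@dataclass(frozen=True)
class PartSketch:
    """Hypothetical Part-like node; not Oryx's implementation."""
    operand: Any
    index: int

    def evaluate(self) -> Any:
        return self.operand.evaluate()[self.index]

    def tree_children(self) -> Iterator[Any]:
        # Assumption: the operand is the only child; the index is static data.
        yield self.operand

    def tree_map(self, fn: Callable[[Any], Any]) -> "PartSketch":
        # Assumption: rebuild the node with fn applied to the operand,
        # keeping the index unchanged.
        return PartSketch(fn(self.operand), self.index)


p = PartSketch(Leaf((10, 20)), 1)
print(list(p.tree_children()))  # [Leaf(value=(10, 20))]

doubled = p.tree_map(lambda leaf: Leaf(tuple(2 * x for x in leaf.value)))
print(doubled.evaluate())  # 40

# The dataclass-generated __eq__ compares fields structurally.
print(doubled == PartSketch(Leaf((20, 40)), 1))  # True
```

Using a frozen dataclass gives structural equality for free, which is convenient when comparing expression trees during matching.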