Module: oryx.experimental.matching.jax_rewrite

Contains utilities for writing JAX patterns and rewriting JAX functions.

This module enables taking JAX functions (i.e. functions that take in JAX arrays and manipulate them with JAX primitives) and using pattern matching and term rewriting to transform them into new JAX functions.

Primer: JAXprs

In order to pattern match against and rewrite JAX functions, we first convert a JAX function into a JAXpr (JAX expression). JAXprs are an intermediate representation that JAX often uses for function transformation. A JAXpr is an A-normal form representation of a function, in which a sequence of equations operates on named values in an environment, one equation at a time. For example, the JAXpr corresponding to the JAX function

def f(x):
  return jnp.exp(x) + 1.

is

{ lambda  ; a.
  let b = exp a
      c = add b 1.0
  in (c,) }

where a is the input to the JAXpr and c is the output. We can convert functions into JAXprs by tracing the function with abstract values, as done in jax.make_jaxpr or oryx.core.trace_util.stage.
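To make the sequential, A-normal-form reading concrete, the sketch below interprets the JAXpr above in plain Python: each equation reads names that are already bound and binds exactly one new name. It is a hypothetical toy model, assuming equations are stored as `(output_name, primitive_name, operands)` tuples and that `math.exp` stands in for JAX's `exp`; it is not how JAX represents or evaluates a JAXpr.

```python
import math

# Hypothetical toy model: primitive name -> Python function standing in for it.
PRIMITIVES = {
    "exp": math.exp,
    "add": lambda a, b: a + b,
}

# The JAXpr for f(x) = exp(x) + 1., as (output_name, primitive, operands).
# Operands are either variable names (str) or inlined literals (float).
EQUATIONS = [
    ("b", "exp", ("a",)),
    ("c", "add", ("b", 1.0)),
]

def eval_anf(invars, outvars, equations, *args):
    """Evaluates an A-normal-form program by binding names one at a time."""
    env = dict(zip(invars, args))  # bind the inputs, e.g. {'a': x}

    def read(v):
        # Names are looked up in the environment; literals are used as-is.
        return env[v] if isinstance(v, str) else v

    for out, prim, operands in equations:
        # Each equation reads already-bound names and binds one new name.
        env[out] = PRIMITIVES[prim](*(read(v) for v in operands))
    return tuple(env[v] for v in outvars)

print(eval_anf(["a"], ["c"], EQUATIONS, 0.0))  # (2.0,)
```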

JAX expressions

We can think of JAXprs as an edge-list representation of a computation graph; we convert them into an expression tree representation more amenable to pattern matching and rewriting. We do this with a custom JAXpr interpreter (see jaxpr_to_expressions) that returns a JaxExpression for each of the outputs of the JAXpr. JaxExpressions are registered with the pattern matching and term rewriting system, so we can write rules to transform them (see rules).
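The conversion from edge list to tree can be sketched by replacing every variable with the expression that defines it. The sketch below is a hypothetical, simplified stand-in for jaxpr_to_expressions that builds nested tuples instead of JaxExpression objects; the tuple layout and the `to_expression_tree` name are assumptions made for illustration.

```python
# Hypothetical toy: equations in edge-list form, (output_name, primitive, operands).
EQUATIONS = [
    ("b", "exp", ("a",)),
    ("c", "add", ("b", 1.0)),
]

def to_expression_tree(invars, outvars, equations):
    """Replaces each variable with the (nested) expression that defines it."""
    # Inputs become ('var', name) leaves.
    defs = {v: ("var", v) for v in invars}
    for out, prim, operands in equations:
        # Names are replaced by their definitions; literals become ('lit', value) leaves.
        args = tuple(defs[v] if isinstance(v, str) else ("lit", v) for v in operands)
        defs[out] = (prim, args)
    return tuple(defs[v] for v in outvars)

print(to_expression_tree(["a"], ["c"], EQUATIONS))
# (('add', (('exp', (('var', 'a'),)), ('lit', 1.0))),)
```

Because every use of a name is replaced by its definition, a value used by several equations appears at each use site in the tree; the tree form gives up explicit sharing in exchange for a structure that patterns can be matched against directly.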

Like abstract values in JAX, every JaxExpression carries a shape and a dtype. We can also "evaluate" JaxExpressions using the evaluate function, which takes an expression and an environment (a binding of names to JAX values) and produces a JAX array result. Together with the conversion into expressions, evaluation lets us transform JAX functions into expressions and back into JAX functions.

We'll quickly go over the core JaxExpressions that may comprise an expression tree.


A JaxVar corresponds to a named input to a JAXpr. Evaluating a JaxVar is just looking up a name in the evaluation environment.


a = JaxVar('x', shape=(), dtype=jnp.float32)
evaluate(a, {'x': 5.}) # ==> 5.


A Literal corresponds to a literal value in a JAXpr, i.e. a value such as a scalar that is inlined directly into an equation. Evaluating a Literal returns the value it was instantiated with.


a = Literal(1.)
evaluate(a, {}) # ==> 1.


Perhaps the most important expression is a Primitive, which corresponds to an equation in a JAXpr. A Primitive is a combination of a JAX primitive, a tuple of expression operands, and an instance of Params, which corresponds to the parameters of the JAX primitive. Evaluating a Primitive involves first recursively evaluating the operands, then calling primitive.bind(*operands, **params) to get the resulting JAX value.


a = Primitive(lax.exp_p, (Literal(0.),), Params())
evaluate(a, {}) # ==> 1.
b = Primitive(lax.exp_p, (JaxVar('x', (), jnp.float32),), Params())
evaluate(b, {'x': 0.}) # ==> 1.
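The evaluation rules for these three kinds of expression fit in one recursive function. The sketch below models them with hypothetical dataclasses `Var`, `Lit` and `Prim` and a `toy_evaluate` function; a plain Python callable plays the role of `primitive.bind`, and a tuple of `(name, value)` pairs stands in for Params. None of these names are part of oryx.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable, Tuple

# Hypothetical, simplified analogues of JaxVar, Literal and Primitive.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Lit:
    value: Any

@dataclass(frozen=True)
class Prim:
    fn: Callable[..., Any]                    # stands in for primitive.bind
    operands: Tuple[Any, ...]
    params: Tuple[Tuple[str, Any], ...] = ()  # hashable stand-in for Params

def toy_evaluate(expr, env):
    if isinstance(expr, Var):
        return env[expr.name]   # look the name up in the environment
    if isinstance(expr, Lit):
        return expr.value       # literals evaluate to themselves
    if isinstance(expr, Prim):
        # Evaluate the operands first, then apply the primitive with its params.
        args = [toy_evaluate(op, env) for op in expr.operands]
        return expr.fn(*args, **dict(expr.params))
    raise TypeError(f"cannot evaluate {expr!r}")

print(toy_evaluate(Prim(math.exp, (Lit(0.0),)), {}))          # 1.0
print(toy_evaluate(Prim(math.exp, (Var("x"),)), {"x": 0.0}))  # 1.0
```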


JAXprs can contain other JAXprs by means of "call primitives", which correspond to transformations like jax.jit and jax.pmap. These call primitives encapsulate another JAXpr, which is evaluated when the call primitive is evaluated. A CallPrimitive expression recursively converts the nested JAXpr into an expression and is evaluated by rebinding names and recursively evaluating its contained expression.


expr = CallPrimitive(
    primitive=xla.xla_call_p,
    operands=(JaxVar('a', (), jnp.float32),),
    expression=(
        Primitive(lax.exp_p, (JaxVar('b', (), jnp.float32),), Params()),),
    params=Params(),
    variable_names=['b'])
evaluate(expr, {'a': 0.}) # ==> 1.
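The name rebinding performed by a CallPrimitive can be sketched with a hypothetical `Call` node: its operands are evaluated in the outer environment, then bound to the body's own variable names in a fresh environment. This toy assumes a single-output body (the example above wraps the body in a tuple) and reuses simplified `Var`/`Prim` nodes; it is not oryx's implementation.

```python
import math
from dataclasses import dataclass
from typing import Any, Tuple

# Hypothetical, simplified expression nodes (not oryx's classes).
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Prim:
    fn: Any
    operands: Tuple[Any, ...]

@dataclass(frozen=True)
class Call:
    operands: Tuple[Any, ...]        # evaluated in the caller's environment
    variable_names: Tuple[str, ...]  # names the body uses for those operands
    expression: Any                  # body of the nested JAXpr (single output here)

def toy_evaluate(expr, env):
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, Prim):
        return expr.fn(*(toy_evaluate(op, env) for op in expr.operands))
    if isinstance(expr, Call):
        args = [toy_evaluate(op, env) for op in expr.operands]
        # Rebind: the body sees only its own names, e.g. {'b': value_of_a}.
        inner_env = dict(zip(expr.variable_names, args))
        return toy_evaluate(expr.expression, inner_env)
    raise TypeError(f"cannot evaluate {expr!r}")

expr = Call(operands=(Var("a"),), variable_names=("b",),
            expression=Prim(math.exp, (Var("b"),)))
print(toy_evaluate(expr, {"a": 0.0}))  # 1.0
```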

Other relevant expressions

  • Part - used to handle primitives with multiple outputs. Part(expr, i) corresponds to indexing into the multi-part expression expr with index i.

  • BoundExpression - used to pre-bind some names to values in an expression so the names don't have to be bound when calling evaluate. For example, this is used to encapsulate a JAXpr and any constant values in it.
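Both helpers can be modeled in a few lines. In the hypothetical sketch below, `Part` indexes into a tuple-valued expression (Python's `divmod` stands in for a multi-output primitive), and `Bound` carries pre-bound `(name, value)` pairs that are merged into the environment at evaluation time. The class names, their fields, and the choice that pre-bound names take precedence are assumptions of this toy, not details of oryx.

```python
from dataclasses import dataclass
from typing import Any, Tuple

# Hypothetical, simplified expression nodes (not oryx's classes).
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Prim:
    fn: Any
    operands: Tuple[Any, ...]

@dataclass(frozen=True)
class Part:
    expression: Any  # an expression that evaluates to a tuple of outputs
    index: int

@dataclass(frozen=True)
class Bound:
    expression: Any
    bindings: Tuple[Tuple[str, Any], ...]  # pre-bound (name, value) pairs

def toy_evaluate(expr, env):
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, Prim):
        return expr.fn(*(toy_evaluate(op, env) for op in expr.operands))
    if isinstance(expr, Part):
        return toy_evaluate(expr.expression, env)[expr.index]
    if isinstance(expr, Bound):
        # Pre-bound names are added to (and override) the caller's environment.
        return toy_evaluate(expr.expression, {**env, **dict(expr.bindings)})
    raise TypeError(f"cannot evaluate {expr!r}")

qr = Prim(divmod, (Var("a"), Var("b")))             # a two-output "primitive"
print(toy_evaluate(Part(qr, 0), {"a": 7, "b": 3}))  # 2
remainder = Bound(Part(qr, 1), (("b", 3),))         # 'b' is baked in
print(toy_evaluate(remainder, {"a": 7}))            # 1
```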

Rewriting JAX expressions

Now that we have a set of baseline JAX expressions, we can write patterns using matcher and rewrite rules using rules.

Let's say we want to turn all calls to exp into calls to log. We first write some convenience functions for constructing patterns and rules.

Exp = lambda x: Primitive(lax.exp_p, (x,), Params())
Log = lambda x: Primitive(lax.log_p, (x,), Params())

We can then write our pattern and accompanying rule.

exp_pattern = Exp(matcher.Var('x'))
exp_to_log_rule = rules.make_rule(exp_pattern, lambda x: Log(x))

We can now rewrite an example expression.

expr = Exp(Literal(5.))
new_expr = exp_to_log_rule(expr)
assert new_expr == Log(Literal(5.))
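To show what matching a pattern variable and building a rule involve, here is a minimal matcher in plain Python over tuple-encoded expressions. `MVar`, `match` and this `make_rule` are hypothetical re-creations for illustration, not the matcher or rules API; in particular, returning the input unchanged when the pattern does not match is a design choice of this sketch.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MVar:
    """Hypothetical pattern variable: matches any subexpression and names it."""
    name: str

def match(pattern, expr, bindings):
    """Returns extended bindings if `expr` matches `pattern`, else None."""
    if isinstance(pattern, MVar):
        if pattern.name in bindings:  # a repeated variable must match the same value
            return bindings if bindings[pattern.name] == expr else None
        return {**bindings, pattern.name: expr}
    if isinstance(pattern, tuple) and isinstance(expr, tuple):
        if len(pattern) != len(expr):
            return None
        for p, e in zip(pattern, expr):
            bindings = match(p, e, bindings)
            if bindings is None:
                return None
        return bindings
    return bindings if pattern == expr else None

def make_rule(pattern, replacement):
    """Builds a rule that rewrites a matching expression and leaves others alone."""
    def rule(expr):
        bindings = match(pattern, expr, {})
        return expr if bindings is None else replacement(**bindings)
    return rule

Exp = lambda x: ("exp", x)
Log = lambda x: ("log", x)

exp_to_log_rule = make_rule(Exp(MVar("x")), lambda x: Log(x))
print(exp_to_log_rule(Exp(5.0)))  # ('log', 5.0)
print(exp_to_log_rule(Log(5.0)))  # ('log', 5.0), no match so returned unchanged
```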

All of the useful machinery in matcher and rules works with JaxExpressions, so we can write complex rewrite rules without too much work. For example, we can use matcher.Segment to capture all of the operands of an expression.

Add = lambda *args: Primitive(lax.add_p, args, Params())
Sub = lambda *args: Primitive(lax.sub_p, args, Params())

add_pattern = Add(matcher.Segment('args'))
add_to_sub_rule = rules.make_rule(add_pattern, lambda args: Sub(*args))
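Segment matching needs backtracking, because a segment can absorb any number of operands. The hypothetical sketch below tries every split point for a `Segment` and keeps the first one for which the remaining patterns also match. Expressions are encoded as `(op, operands)` tuples and non-segment operand patterns are compared by equality only; both are simplifications relative to matcher.Segment.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Segment:
    """Hypothetical segment variable: matches a run of zero or more operands."""
    name: str

Add = lambda *args: ("add", args)
Sub = lambda *args: ("sub", args)

def match_operands(patterns, operands, bindings):
    if not patterns:
        return bindings if not operands else None
    head, rest = patterns[0], patterns[1:]
    if isinstance(head, Segment):
        # Try each split: the segment takes operands[:i], the rest must match operands[i:].
        for i in range(len(operands) + 1):
            result = match_operands(rest, operands[i:], {**bindings, head.name: operands[:i]})
            if result is not None:
                return result
        return None
    # In this toy, a non-segment operand pattern must equal the operand exactly.
    if operands and head == operands[0]:
        return match_operands(rest, operands[1:], bindings)
    return None

def make_rule(pattern, replacement):
    op, operand_patterns = pattern
    def rule(expr):
        bindings = None
        if expr[0] == op:
            bindings = match_operands(operand_patterns, expr[1], {})
        return expr if bindings is None else replacement(**bindings)
    return rule

add_to_sub_rule = make_rule(Add(Segment("args")), lambda args: Sub(*args))
print(add_to_sub_rule(Add(1.0, 2.0, 3.0)))  # ('sub', (1.0, 2.0, 3.0))

drop_zero = make_rule(Add(Segment("xs"), 0.0), lambda xs: Add(*xs))
print(drop_zero(Add(1.0, 2.0, 0.0)))  # ('add', (1.0, 2.0))
```

The second rule shows why the split search matters: the segment has to stop before the trailing `0.0` for the rest of the pattern to match.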

Rewriting JAX functions

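We provide a JAX function transformation, rewrite, that takes a function and a rewrite rule and transforms the function by first converting it into an expression, applying the rule to rewrite the expression, and then evaluating the rewritten expression with the function's inputs. In the example below, the rule is wrapped with rules.term_rewriter so that it is applied to the subexpressions of the traced function, not just to the top-level expression.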


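import jax.numpy as jnp
from jax import lax
from oryx.experimental.matching import matcher, rules
from oryx.experimental.matching.jax_rewrite import Params, Primitive, rewrite

Exp = lambda x: Primitive(lax.exp_p, (x,), Params())
Log = lambda x: Primitive(lax.log_p, (x,), Params())

exp_pattern = Exp(matcher.Var('x'))
# Wrap the rule so it is applied throughout the expression tree.
exp_to_log_rule = rules.term_rewriter(
    rules.make_rule(exp_pattern, lambda x: Log(x)))

def f(x):
  return jnp.exp(x) + 1.
new_f = rewrite(f, exp_to_log_rule)
new_f(1.) # ==> 1. (i.e. log(1.) + 1.)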
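The convert, rewrite, evaluate pipeline can be modeled end to end without JAX. In the hypothetical sketch below, the expression for `f` is written out by hand in place of tracing, a toy `term_rewriter` applies its rules bottom-up and repeats until nothing changes, and `evaluate` plays the role of running the rewritten expression on the function's inputs. The traversal strategy is an assumption of this toy; rules.term_rewriter may proceed differently.

```python
import math

# Hypothetical toy expressions: ("var", name), ("lit", value) or (op, operand, ...).
OPS = {"exp": math.exp, "log": math.log, "add": lambda a, b: a + b}

def evaluate(expr, env):
    tag = expr[0]
    if tag == "var":
        return env[expr[1]]
    if tag == "lit":
        return expr[1]
    return OPS[tag](*(evaluate(e, env) for e in expr[1:]))

def exp_to_log(expr):
    """A single rule, exp(x) -> log(x); returns expr unchanged when it does not apply."""
    return ("log", expr[1]) if expr[0] == "exp" else expr

def term_rewriter(*rules):
    """Applies rules bottom-up, repeating until the expression stops changing."""
    def rewrite_once(expr):
        if expr[0] not in ("var", "lit"):
            # Rewrite the children first, then try the rules on this node.
            expr = (expr[0],) + tuple(rewrite_once(e) for e in expr[1:])
        for rule in rules:
            expr = rule(expr)
        return expr

    def rewrite_all(expr):
        while True:  # a rule set that never settles would loop forever
            new_expr = rewrite_once(expr)
            if new_expr == expr:
                return expr
            expr = new_expr
    return rewrite_all

# f(x) = exp(x) + 1., already converted into an expression tree.
f_expr = ("add", ("exp", ("var", "x")), ("lit", 1.0))
new_f_expr = term_rewriter(exp_to_log)(f_expr)
print(new_f_expr)                        # ('add', ('log', ('var', 'x')), ('lit', 1.0))
print(evaluate(new_f_expr, {"x": 1.0}))  # 1.0
```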

Modules

matcher module: A basic pattern matching system.

rules module: A simple term-rewriting system.

trace_util module: Module for JAX tracing utility functions.

Classes

class BoundExpression: Represents JAX expressions with closed over constants.

class CallPrimitive: Encapsulates JAX CallPrimitives like jit and pmap.

class JaxExpression: A node in an expression tree.

class JaxVar: An expression that looks up a provided name in the environment.

class Literal: An expression that evaluates to a provided scalar literal value.

class Params: An immutable dictionary used to represent parameters of JAX primitives.

class Part: Used to select the outputs of JAX primitives with multiple outputs.

class Primitive: A JaxExpression corresponding to a jax.core.Primitive.

Functions

evaluate(...): Evaluates expressions into JAX values.

evaluate_jax_expression(...): Evaluates a JaxExpression in an environment.


evaluate_value(...): Default evaluate function for numerical types.

jaxpr_to_expressions(...): Converts a JAXpr into a tuple of output JaxExpressions.

make_bound_expression(...): Returns a function that traces a function to produce an expression.

primitive_to_expression(...): Converts a JAX primitive into a Primitive expression.

rewrite(...): Rewrites a JAX function according to a rewrite rule.

Type Aliases