Contains utilities for writing JAX patterns and rewriting JAX functions.
This module enables taking JAX functions (i.e. functions that take in JAX arrays and manipulate them with JAX primitives) and using pattern matching and term rewriting to transform them into new JAX functions.
Primer: JAXprs
In order to pattern match against and rewrite JAX functions, we first convert a JAX function into a JAXpr (JAX expression). JAXprs are an intermediate representation that JAX often uses for function transformation. A JAXpr is an A-normal form representation of a function in which a sequence of equations operates, in order, on named values in an environment. For example, the JAXpr corresponding to the JAX function
```python
def f(x):
  return jnp.exp(x) + 1.
```
is
```
{ lambda ; a.
  let b = exp a
      c = add b 1.0
  in (c,) }
```
where `a` is the input to the JAXpr and `c` is the output. We can convert functions into JAXprs by tracing the function with abstract values, as done in `jax.make_jaxpr` or `oryx.core.trace_util.stage`.
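The "equations over named values" reading can be made concrete with a toy, pure-Python model of the JAXpr above. This is not how JAX stores JAXprs: the `(output, primitive, operands)` triples, the `PRIMITIVES` table, and `eval_anf` are hypothetical stand-ins, and `math.exp` replaces the real `exp` primitive.

```python
import math

# Hypothetical stand-in for the JAXpr of f(x) = exp(x) + 1.
# Each equation is (output variable, primitive name, operands), where an
# operand is either a variable name (str) or a literal number.
INVARS = ["a"]
EQNS = [
    ("b", "exp", ["a"]),
    ("c", "add", ["b", 1.0]),
]
OUTVARS = ["c"]

# Toy primitive table mapping primitive names to Python implementations.
PRIMITIVES = {
    "exp": math.exp,
    "add": lambda x, y: x + y,
}


def eval_anf(invars, eqns, outvars, args):
    """Evaluates an A-normal-form program one equation at a time."""
    env = dict(zip(invars, args))  # bind the inputs by name

    def read(operand):
        # Names are looked up in the environment; literals are used as-is.
        return env[operand] if isinstance(operand, str) else operand

    for out, prim, operands in eqns:
        # Each equation writes exactly one new name into the environment.
        env[out] = PRIMITIVES[prim](*(read(o) for o in operands))
    return tuple(env[v] for v in outvars)


print(eval_anf(INVARS, EQNS, OUTVARS, [0.0]))  # (2.0,)
```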
JAX expressions
We can think of JAXprs as an edge-list representation of a computation graph; we convert them into an expression tree representation more amenable to pattern matching and rewriting. We do this with a custom JAXpr interpreter (see `jaxpr_to_expressions`) that returns a `JaxExpression` for each of the outputs of the JAXpr. `JaxExpression`s are registered with the pattern matching and term rewriting system, so we can write rules to transform them (see `rules`).
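A minimal sketch of the edge-list-to-tree conversion, assuming a toy equation format of `(output, primitive_name, operands)` triples (hypothetical, not the real JAXpr classes or `jaxpr_to_expressions` itself). The interpreter loop is the one an evaluator would use, except that each equation produces a tree node instead of a value, so every output ends up as a tree over the inputs and literals.

```python
# Toy model: convert equations (an edge list) into one expression tree per output.
# Trees are nested tuples: ("var", name), ("lit", value), or (prim, *operand_trees).
EQNS = [
    ("b", "exp", ["a"]),
    ("c", "add", ["b", 1.0]),
]


def eqns_to_trees(invars, eqns, outvars):
    """Interprets equations symbolically, building trees instead of values."""
    env = {v: ("var", v) for v in invars}  # inputs become leaf variables

    def read(operand):
        # Names resolve to the tree built so far; numbers become literal leaves.
        return env[operand] if isinstance(operand, str) else ("lit", operand)

    for out, prim, operands in eqns:
        # Each equation becomes a node whose children are its operands' trees.
        env[out] = (prim,) + tuple(read(o) for o in operands)
    return tuple(env[v] for v in outvars)


trees = eqns_to_trees(["a"], EQNS, ["c"])
print(trees)  # (('add', ('exp', ('var', 'a')), ('lit', 1.0)),)
```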
The basic attributes of a `JaxExpression` parallel those of abstract values in JAX, namely a `shape` and a `dtype`. We can also "evaluate" `JaxExpression`s using the `evaluate` function, which takes an expression and an environment (a binding of names to JAX values) and produces a JAX array result. Evaluation, together with tracing, enables us to transform JAX functions into expressions and back into JAX functions.
We'll quickly go over the core `JaxExpression`s that can make up an expression tree.
JaxVar
A `JaxVar` corresponds to a named input to a JAXpr. Evaluating a `JaxVar` is just looking up a name in the evaluation environment.
Example:
```python
a = JaxVar('x', shape=(), dtype=jnp.float32)
evaluate(a, {'x': 5.})  # ==> 5.
```
Literal
A `Literal` corresponds to a literal value in a JAXpr, i.e. a value, such as a scalar, that is inlined directly into the equations. Evaluating a `Literal` returns the literal value it was instantiated with.
Example:
```python
a = Literal(1.)
evaluate(a, {})  # ==> 1.
```
Primitive
Perhaps the most important expression is a `Primitive`, which corresponds to an equation in a JAXpr. A `Primitive` is a combination of a JAX primitive, a tuple of expression operands, and an instance of `Params`, which corresponds to the parameters of the JAX primitive. Evaluating a `Primitive` involves first recursively evaluating the operands, then calling `primitive.bind(*operands, **params)` to get the resulting JAX value.
Example:
```python
a = Primitive(lax.exp_p, (Literal(0.),), Params())
evaluate(a, {})  # ==> 1.
b = Primitive(lax.exp_p, (JaxVar('x', (), jnp.float32),), Params())
evaluate(b, {'x': 0.})  # ==> 1.
```
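The recursive evaluation rule for `JaxVar`, `Literal`, and `Primitive` can be sketched in a few lines of plain Python. The classes `Var`, `Lit`, and `Prim` and the function `toy_evaluate` below are hypothetical simplifications, not oryx's classes: a Python callable stands in for `primitive.bind`, and a tuple of key/value pairs stands in for `Params`.

```python
import dataclasses
import math
from typing import Any, Callable, Dict, Tuple


@dataclasses.dataclass(frozen=True)
class Var:
    name: str


@dataclasses.dataclass(frozen=True)
class Lit:
    value: Any


@dataclasses.dataclass(frozen=True)
class Prim:
    fn: Callable[..., Any]                     # stands in for primitive.bind
    operands: Tuple[Any, ...]
    params: Tuple[Tuple[str, Any], ...] = ()   # stands in for Params


def toy_evaluate(expr, env: Dict[str, Any]):
    if isinstance(expr, Var):
        return env[expr.name]                  # look the name up
    if isinstance(expr, Lit):
        return expr.value                      # return the stored value
    if isinstance(expr, Prim):
        # Evaluate operands first, then apply the primitive with its params.
        args = [toy_evaluate(o, env) for o in expr.operands]
        return expr.fn(*args, **dict(expr.params))
    raise TypeError(f"unknown expression: {expr!r}")


a = Prim(math.exp, (Lit(0.0),))
b = Prim(math.exp, (Var("x"),))
print(toy_evaluate(a, {}), toy_evaluate(b, {"x": 0.0}))  # 1.0 1.0

# Params are forwarded as keyword arguments, like bind(*operands, **params).
rounded = Prim(round, (Lit(3.14159),), (("ndigits", 2),))
print(toy_evaluate(rounded, {}))  # 3.14
```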
CallPrimitive
JAXprs can contain other JAXprs by means of "call primitives", which correspond to transformations like `jax.jit` and `jax.pmap`. These call primitives encapsulate another JAXpr, which is evaluated when the call primitive is evaluated. A `CallPrimitive` expression recursively converts the nested JAXpr into an expression and is evaluated by binding the nested expression's variable names to the evaluated operands and then recursively evaluating the contained expression.
Example:
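```python
# Note: xla.xla_call_p exists only in older JAX releases; newer releases
# stage jax.jit as pjit instead (see PjitPrimitive).
expr = CallPrimitive(
    primitive=xla.xla_call_p,
    operands=(JaxVar('a', (), jnp.float32),),
    expression=(
        Primitive(lax.exp_p, (JaxVar('b', (), jnp.float32),), Params()),),
    params=Params(),
    variable_names=['b'])  # 'b' is bound to the value of operand 'a'
evaluate(expr, {'a': 0.})  # ==> 1.
```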
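The name rebinding that a `CallPrimitive` performs can be illustrated with a toy tree evaluator. It is hypothetical: tuples tagged `"var"`, `"lit"`, and `"call"` stand in for the real expression classes. The operands are evaluated in the outer environment, and the nested body only sees the inner variable names bound to those values.

```python
import math

# Toy trees: ("var", name), ("lit", value), ("call", operands, body, names),
# or (fn, *operands) for an ordinary primitive.


def toy_evaluate(expr, env):
    """Evaluates a toy expression tree in the environment `env`."""
    tag = expr[0]
    if tag == "var":
        return env[expr[1]]
    if tag == "lit":
        return expr[1]
    if tag == "call":
        _, operands, body, names = expr
        # Evaluate the operands in the *outer* environment...
        args = [toy_evaluate(o, env) for o in operands]
        # ...then evaluate the body in a fresh environment of inner names only.
        inner_env = dict(zip(names, args))
        return tuple(toy_evaluate(e, inner_env) for e in body)
    fn, *operands = expr
    return fn(*(toy_evaluate(o, env) for o in operands))


expr = (
    "call",
    (("var", "a"),),              # operands, evaluated in the outer env
    ((math.exp, ("var", "b")),),  # nested body with one output: exp(b)
    ["b"],                        # names bound inside the call
)
print(toy_evaluate(expr, {"a": 0.0}))  # (1.0,)

# The outer name "a" is not visible inside the body; only "b" is bound there.
leaky = ("call", (("var", "a"),), ((math.exp, ("var", "a")),), ["b"])
try:
    toy_evaluate(leaky, {"a": 0.0})
except KeyError as err:
    print("unbound inside call:", err)  # unbound inside call: 'a'
```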
Other relevant expressions
- `Part`: used to handle primitives with multiple outputs. `Part(expr, i)` corresponds to indexing into the multi-part expression `expr` with index `i`.
- `BoundExpression`: used to pre-bind some names to values in an expression so the names don't have to be bound when calling `evaluate`. For example, this is used to encapsulate a JAXpr and any constant values in it.
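A toy sketch of `Part` and `BoundExpression` semantics, again using hypothetical tagged tuples rather than the real classes. Python's `divmod` plays the role of a primitive with two outputs. The `"bound"` node merges its pre-bound names into the environment; letting pre-bound names win over caller-supplied ones is a choice made for this sketch, not documented oryx behavior.

```python
# Toy trees: ("var", name), ("lit", value), ("part", expr, i),
# ("bound", expr, bindings), or (fn, *operands).


def toy_evaluate(expr, env):
    """Evaluates a toy expression tree in the environment `env`."""
    tag = expr[0]
    if tag == "var":
        return env[expr[1]]
    if tag == "lit":
        return expr[1]
    if tag == "part":
        _, multi, i = expr
        return toy_evaluate(multi, env)[i]  # select one output of a multi-output node
    if tag == "bound":
        _, inner, bindings = expr
        # Pre-bound names are merged in; in this toy they take precedence.
        return toy_evaluate(inner, {**env, **bindings})
    fn, *operands = expr
    return fn(*(toy_evaluate(o, env) for o in operands))


# divmod returns (quotient, remainder); Part picks one of them.
quotient = ("part", (divmod, ("var", "x"), ("lit", 3)), 0)
remainder = ("part", (divmod, ("var", "x"), ("lit", 3)), 1)
print(toy_evaluate(quotient, {"x": 7}), toy_evaluate(remainder, {"x": 7}))  # 2 1

# The "constant" c is bound ahead of time, so callers only supply x.
scaled = ("bound", (lambda x, c: x * c, ("var", "x"), ("var", "c")), {"c": 10})
print(toy_evaluate(scaled, {"x": 4}))  # 40
```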
Rewriting JAX expressions
Now that we have a set of baseline JAX expressions, we can write patterns using `matcher` and rewrite rules using `rules`.
Let's say we want to turn all calls to `exp` into calls to `log`. We first write some convenience functions for constructing patterns and rules.
```python
Exp = lambda x: Primitive(lax.exp_p, (x,), Params())
Log = lambda x: Primitive(lax.log_p, (x,), Params())
```
We can then write our pattern and accompanying rule.
```python
exp_pattern = Exp(matcher.Var('x'))
exp_to_log_rule = rules.make_rule(exp_pattern, lambda x: Log(x))
```
We can now rewrite an example expression.
```python
expr = Exp(Literal(5.))
new_expr = exp_to_log_rule(expr)
assert new_expr == Log(Literal(5.))
```
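To show roughly what a rule like `exp_to_log_rule` does, here is a minimal matcher and rule constructor over tuples. `PVar`, `match`, and `make_toy_rule` are hypothetical and far simpler than `matcher` and `rules`. In this sketch the bindings are passed to the replacement function as keyword arguments named after the pattern variables; that is consistent with `lambda x: Log(x)` above but is an assumption about the real `rules.make_rule`.

```python
from typing import Any, Callable, Dict, Optional


class PVar:
    """Hypothetical pattern variable: matches any subexpression and binds it."""

    def __init__(self, name: str):
        self.name = name


def match(pattern, expr, bindings: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Returns extended bindings if `pattern` matches `expr`, else None."""
    if isinstance(pattern, PVar):
        if pattern.name in bindings:  # a repeated variable must match the same value
            return bindings if bindings[pattern.name] == expr else None
        return {**bindings, pattern.name: expr}
    if isinstance(pattern, tuple) and isinstance(expr, tuple):
        if len(pattern) != len(expr):
            return None
        for p, e in zip(pattern, expr):  # match children left to right
            bindings = match(p, e, bindings)
            if bindings is None:
                return None
        return bindings
    return bindings if pattern == expr else None  # constants must be equal


def make_toy_rule(pattern, replacement: Callable[..., Any]):
    """Returns a rule that rewrites a matching expression, else returns it unchanged."""

    def rule(expr):
        bindings = match(pattern, expr, {})
        return expr if bindings is None else replacement(**bindings)

    return rule


Exp = lambda x: ("exp", x)
Log = lambda x: ("log", x)
exp_to_log = make_toy_rule(Exp(PVar("x")), lambda x: Log(x))
assert exp_to_log(Exp(("lit", 5.0))) == Log(("lit", 5.0))
assert exp_to_log(Log(("lit", 5.0))) == Log(("lit", 5.0))  # no match: unchanged
```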
All of the useful machinery in `matcher` and `rules` works with `JaxExpression`s, so we can make complex rewrite rules without too much work. For example, we can use `matcher.Segment`s to obtain all of the operands of an expression.
```python
Add = lambda *args: Primitive(lax.add_p, args, Params())
Sub = lambda *args: Primitive(lax.sub_p, args, Params())
add_pattern = Add(matcher.Segment('args'))
add_to_sub_rule = rules.make_rule(add_pattern, lambda args: Sub(*args))
```
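A segment differs from an ordinary pattern variable in that it can match a run of zero or more operands, which requires trying different split points. The sketch below (hypothetical `PSegment` and `match_seq`, flat patterns only, not `matcher`'s implementation) shows that backtracking search. Since `lax.add_p` is a binary primitive, the segment in `add_pattern` above always captures exactly two operands; segments become more useful when combined with other pattern elements, as in the last call below.

```python
from typing import Any, Dict, Optional, Sequence


class PVar:
    """Hypothetical pattern variable: matches exactly one item."""

    def __init__(self, name):
        self.name = name


class PSegment:
    """Hypothetical segment variable: matches a run of zero or more items."""

    def __init__(self, name):
        self.name = name


def match_seq(patterns: Sequence[Any], items: Sequence[Any],
              bindings: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Matches a flat pattern sequence against a flat item sequence."""
    if not patterns:
        return bindings if not items else None
    first, rest = patterns[0], patterns[1:]
    if isinstance(first, PSegment):
        # Try every split point, shortest run first, backtracking on failure.
        for k in range(len(items) + 1):
            result = match_seq(rest, items[k:], {**bindings, first.name: tuple(items[:k])})
            if result is not None:
                return result
        return None
    if not items:
        return None
    if isinstance(first, PVar):
        bindings = {**bindings, first.name: items[0]}
    elif first != items[0]:
        return None
    return match_seq(rest, items[1:], bindings)


Add = lambda *args: ("add",) + args
Sub = lambda *args: ("sub",) + args
add_pattern = Add(PSegment("args"))


def add_to_sub(expr):
    bindings = match_seq(add_pattern, expr, {})
    return expr if bindings is None else Sub(*bindings["args"])


print(add_to_sub(Add("x", "y")))  # ('sub', 'x', 'y')
print(match_seq(("add", PVar("first"), PSegment("rest")), ("add", 1, 2, 3), {}))
# {'first': 1, 'rest': (2, 3)}
```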
Rewriting JAX functions
We provide a JAX function transformation, `rewrite`, that, when given a set of rewrite rules, transforms an input function by first converting it into an expression, applying the rules to rewrite the expression, and then evaluating the rewritten expression with the function's inputs.
Example:
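```python
import jax.numpy as jnp
from jax import lax
from oryx.experimental.matching import matcher
from oryx.experimental.matching import rules
from oryx.experimental.matching.jax_rewrite import Params, Primitive, rewrite

Exp = lambda x: Primitive(lax.exp_p, (x,), Params())
Log = lambda x: Primitive(lax.log_p, (x,), Params())
exp_pattern = Exp(matcher.Var('x'))
exp_to_log_rule = rules.term_rewriter(
    rules.make_rule(exp_pattern, lambda x: Log(x)))

def f(x):
  return jnp.exp(x) + 1.

new_f = rewrite(f, exp_to_log_rule)
new_f(1.)  # ==> 1. (i.e. log(1.) + 1.)
```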
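The three steps of `rewrite` (build an expression, rewrite it, evaluate it) can be mimicked in a toy setting. This is a sketch, not oryx's implementation: expressions are tagged tuples, the expression for `f` is written by hand instead of traced, and `rewrite_bottom_up` applies a rule to children before parents and repeats until nothing changes. That is one common term-rewriting strategy, not necessarily the one `rules.term_rewriter` uses.

```python
import math

# Toy trees: ("var", name), ("lit", value), or (op_name, *operands).
OPS = {"exp": math.exp, "log": math.log, "add": lambda x, y: x + y}


def toy_evaluate(expr, env):
    """Evaluates a toy expression tree in the environment `env`."""
    tag = expr[0]
    if tag == "var":
        return env[expr[1]]
    if tag == "lit":
        return expr[1]
    return OPS[tag](*(toy_evaluate(e, env) for e in expr[1:]))


def rewrite_bottom_up(rule, expr):
    """Applies `rule` to children first, then to the node, until nothing changes."""
    if expr[0] not in ("var", "lit"):
        expr = (expr[0],) + tuple(rewrite_bottom_up(rule, e) for e in expr[1:])
    new = rule(expr)
    return expr if new == expr else rewrite_bottom_up(rule, new)


def exp_to_log(expr):
    return ("log",) + expr[1:] if expr[0] == "exp" else expr


# Hand-written expression for f(x) = exp(x) + 1., standing in for tracing.
f_expr = ("add", ("exp", ("var", "x")), ("lit", 1.0))
new_expr = rewrite_bottom_up(exp_to_log, f_expr)


def new_f(x):
    return toy_evaluate(new_expr, {"x": x})


print(new_expr)    # ('add', ('log', ('var', 'x')), ('lit', 1.0))
print(new_f(1.0))  # 1.0
```

Repeating until nothing changes terminates only if the rules cannot undo each other; pairing exp-to-log with a log-to-exp rule, for example, would recurse forever in this sketch.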
Modules
- `matcher` module: A basic pattern matching system.
- `rules` module: A simple term-rewriting system.
- `trace_util` module: Module for JAX tracing utility functions.
Classes
- `class BoundExpression`: Represents JAX expressions with closed-over constants.
- `class CallPrimitive`: Encapsulates JAX `CallPrimitive`s like `jit` and `pmap`.
- `class JaxExpression`: A node in an expression tree.
- `class JaxVar`: An expression that looks up a provided name in the environment.
- `class Literal`: An expression that evaluates to a provided scalar literal value.
- `class Params`: An immutable dictionary used to represent parameters of JAX primitives.
- `class Part`: Used to select the outputs of JAX primitives with multiple outputs.
- `class PjitPrimitive`: Encapsulates JAX's `pjit` primitive.
- `class Primitive`: A `JaxExpression` corresponding to a `jax.core.Primitive`.
Functions
- `evaluate(...)`: Evaluates expressions into JAX values.
- `evaluate_jax_expression(...)`: Evaluates a `JaxExpression` in an environment.
- `evaluate_value(...)`: Default evaluate function for numerical types.
- `jaxpr_to_expressions(...)`: Converts a JAXpr into a tuple of output `JaxExpression`s.
- `make_bound_expression(...)`: Returns a function that traces a function to produce an expression.
- `primitive_to_expression(...)`: Converts a JAX primitive into a `Primitive` expression.
- `rewrite(...)`: Rewrites a JAX function according to a rewrite rule.
Type Aliases
| Other Members | |
|---|---|
| `custom_expressions` | |