Contains utilities for writing JAX patterns and rewriting JAX functions.
This module enables taking JAX functions (i.e. functions that take in JAX arrays and manipulate them with JAX primitives) and using pattern matching and term rewriting to transform them into new JAX functions.
Primer: JAXprs
In order to pattern match against and rewrite JAX functions, we first convert a JAX function into a JAXpr (JAX expression). JAXprs are an intermediate representation that JAX uses for function transformations. A JAXpr is an A-normal form representation of a function, in which a sequence of equations operates on named values in an environment, one equation at a time. For example, the JAXpr corresponding to the JAX function
def f(x):
  return jnp.exp(x) + 1.
is
{ lambda ; a.
  let b = exp a
      c = add b 1.0
  in (c,) }
where a is the input to the JAXpr and c is the output. We can convert
functions into JAXprs by tracing the function with abstract values, as done in
jax.make_jaxpr or oryx.core.trace_util.stage.
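To make the A-normal form concrete, here is a minimal pure-Python sketch that stores the JAXpr above as an ordered list of equations and evaluates it by threading an environment through them. The dictionary layout and the `eval_jaxpr_sketch` helper are hypothetical stand-ins, not JAX's internal data structures, and `math.exp` plays the role of the `exp` primitive.

```python
import math
import operator

# Hypothetical primitive table: primitive name -> Python implementation.
PRIMITIVES = {"exp": math.exp, "add": operator.add}

# The JAXpr { lambda ; a. let b = exp a; c = add b 1.0 in (c,) } as data.
# Each equation is (output name, primitive name, operands); an operand is
# either a variable name (str) or an inlined literal (float).
JAXPR = {
    "invars": ["a"],
    "eqns": [
        ("b", "exp", ["a"]),
        ("c", "add", ["b", 1.0]),
    ],
    "outvars": ["c"],
}


def eval_jaxpr_sketch(jaxpr, *args):
    """Evaluates equations in order, binding each result in an environment."""
    env = dict(zip(jaxpr["invars"], args))

    def read(operand):
        # Strings are variable references; anything else is a literal.
        return env[operand] if isinstance(operand, str) else operand

    for out, prim, operands in jaxpr["eqns"]:
        env[out] = PRIMITIVES[prim](*map(read, operands))
    return tuple(env[v] for v in jaxpr["outvars"])


print(eval_jaxpr_sketch(JAXPR, 0.0))  # (2.0,)
```

Because every intermediate value gets a name, each equation only ever reads names bound by earlier equations, which is what makes a single in-order pass sufficient.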
JAX expressions
We can think of JAXprs as an edge-list representation of a computation graph;
we convert them into an expression tree representation more amenable to pattern
matching and rewriting. We do this with a custom JAXpr interpreter (see
jaxpr_to_expressions) that returns a JaxExpression for each of the outputs
of the JAXpr. JaxExpressions are registered with the pattern matching and term
rewriting system, so we can write rules to transform them (see rules).
Like abstract values in JAX, the basic elements of a JaxExpression are a
shape and a dtype. We can also "evaluate" JaxExpressions using the
evaluate function, which takes an expression and an environment (a
binding of names to JAX values) and produces a JAX array result. Evaluation
completes the round trip: we transform a JAX function into expressions and,
by evaluating them, back into a JAX function.
We'll quickly go over the core JaxExpressions that may comprise an expression
tree.
JaxVar
A JaxVar corresponds to a named input to a JAXpr. Evaluating a JaxVar
is just looking up a name in the evaluation environment.
Example:
a = JaxVar('x', shape=(), dtype=jnp.float32)
evaluate(a, {'x': 5.}) # ==> 5.
Literal
A Literal corresponds to a literal value in a JAXpr: a value, such as a scalar,
that is inlined directly into an equation. Evaluating a Literal returns the
value it was instantiated with.
Example:
a = Literal(1.)
evaluate(a, {}) # ==> 1.
Primitive
Perhaps the most important expression is a Primitive, which corresponds to an
equation in a JAXpr. A Primitive combines a JAX primitive, a tuple of
expression operands, and an instance of Params, which holds the parameters of
the JAX primitive. Evaluating a Primitive first recursively evaluates the
operands, then calls primitive.bind(*operands, **params) to get the resulting
JAX value.
Example:
a = Primitive(lax.exp_p, (Literal(0.),), Params())
evaluate(a, {}) # ==> 1.
b = Primitive(lax.exp_p, (JaxVar('x', (), jnp.float32),), Params())
evaluate(b, {'x': 0.}) # ==> 1.
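The recursive evaluation of JaxVar, Literal, and Primitive can be mimicked in plain Python. `ToyVar`, `ToyLiteral`, and `ToyPrimitive` are hypothetical stand-ins; a toy primitive is just a Python callable, whereas the real Primitive calls `primitive.bind(*operands, **params)`.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass(frozen=True)
class ToyVar:
    name: str


@dataclass(frozen=True)
class ToyLiteral:
    value: Any


@dataclass(frozen=True)
class ToyPrimitive:
    fn: Callable[..., Any]
    operands: Tuple[Any, ...]
    params: Tuple[Tuple[str, Any], ...] = ()  # immutable stand-in for Params


def toy_evaluate(expr, env: Dict[str, Any]):
    """Recursively evaluates a toy expression tree in an environment."""
    if isinstance(expr, ToyVar):
        return env[expr.name]  # look the name up in the environment
    if isinstance(expr, ToyLiteral):
        return expr.value  # return the stored value unchanged
    if isinstance(expr, ToyPrimitive):
        # Evaluate operands first, then apply the primitive with its params.
        args = [toy_evaluate(op, env) for op in expr.operands]
        return expr.fn(*args, **dict(expr.params))
    raise TypeError(f"Unknown expression: {expr!r}")


a = ToyPrimitive(math.exp, (ToyLiteral(0.0),))
b = ToyPrimitive(math.exp, (ToyVar("x"),))
print(toy_evaluate(a, {}), toy_evaluate(b, {"x": 0.0}))  # 1.0 1.0
```

Storing params as a tuple of pairs keeps each node hashable and immutable, which loosely mirrors why the module describes Params as an immutable dictionary.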
CallPrimitive
JAXprs can contain other JAXprs by means of "call primitives", which correspond
to transformations like jax.jit and jax.pmap. A call primitive encapsulates
another JAXpr, which is evaluated when the call primitive is evaluated. A
CallPrimitive expression recursively converts the nested JAXpr into an
expression. It is evaluated by binding the evaluated operands to the nested
expression's variable names and then recursively evaluating that nested
expression.
Example:
expr = CallPrimitive(
    primitive=xla.xla_call_p,
    operands=(JaxVar('a', (), jnp.float32),),
    expression=(
        Primitive(lax.exp_p, (JaxVar('b', (), jnp.float32),), Params()),),
    params=Params(),
    variable_names=['b'])
evaluate(expr, {'a': 0.}) # ==> 1.
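The name rebinding performed by a call primitive can be sketched with a tagged-tuple evaluator. Everything here is hypothetical: the `"call"` node evaluates its operands in the outer environment and then evaluates the inner expressions in a fresh environment that contains only the inner variable names (an assumption of this toy). It returns one value per inner output as a tuple; how the real CallPrimitive packages its outputs is not modeled here.

```python
import math

# Hypothetical toy encoding of expressions as tagged tuples:
#   ("var", name) | ("lit", value) | ("prim", fn, operands)
#   ("call", operands, inner_exprs, variable_names)


def toy_evaluate(expr, env):
    tag = expr[0]
    if tag == "var":
        return env[expr[1]]
    if tag == "lit":
        return expr[1]
    if tag == "prim":
        _, fn, operands = expr
        return fn(*(toy_evaluate(op, env) for op in operands))
    if tag == "call":
        _, operands, inner_exprs, variable_names = expr
        # Evaluate the operands in the outer environment...
        values = [toy_evaluate(op, env) for op in operands]
        # ...then rebind them to the inner names in a fresh environment.
        inner_env = dict(zip(variable_names, values))
        return tuple(toy_evaluate(e, inner_env) for e in inner_exprs)
    raise TypeError(f"Unknown expression tag: {tag!r}")


call = ("call",
        (("var", "a"),),                         # outer operand
        (("prim", math.exp, (("var", "b"),)),),  # inner expression uses 'b'
        ["b"])                                   # inner variable names
print(toy_evaluate(call, {"a": 0.0}))  # (1.0,)
```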
Other relevant expressions
Part: used to handle primitives with multiple outputs. Part(expr, i) corresponds to indexing into the multi-part expression expr with index i.
BoundExpression: used to pre-bind some names to values in an expression so the names don't have to be bound when calling evaluate. For example, this is used to encapsulate a JAXpr and any constant values in it.
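Both ideas can be illustrated with a tiny tagged-tuple evaluator. The `"part"` and `"bound"` nodes are hypothetical analogues of Part and BoundExpression; in particular, merging pre-bound names over the caller's environment (pre-bound names win) is an assumption of this sketch, not documented behavior.

```python
# Hypothetical toy versions of Part and BoundExpression, built on a tiny
# tuple-based evaluator:  ("var", name) | ("prim", fn, operands)
#                          ("part", expr, index) | ("bound", expr, bindings)


def toy_evaluate(expr, env):
    tag = expr[0]
    if tag == "var":
        return env[expr[1]]
    if tag == "prim":
        _, fn, operands = expr
        return fn(*(toy_evaluate(op, env) for op in operands))
    if tag == "part":
        _, inner, index = expr
        return toy_evaluate(inner, env)[index]  # select one of several outputs
    if tag == "bound":
        _, inner, bindings = expr
        # Pre-bound names are merged over the caller's environment.
        return toy_evaluate(inner, {**env, **bindings})
    raise TypeError(f"Unknown expression tag: {tag!r}")


# divmod has two outputs, so each one is selected with a "part" node.
qr = ("prim", divmod, (("var", "x"), ("var", "y")))
quotient = ("part", qr, 0)
remainder = ("part", qr, 1)

# Bind y ahead of time so only x must be supplied at evaluation time.
bound_rem = ("bound", remainder, {"y": 3})
print(toy_evaluate(quotient, {"x": 17, "y": 5}),
      toy_evaluate(bound_rem, {"x": 17}))  # 3 2
```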
Rewriting JAX expressions
Now that we have a set of baseline JAX expressions, we can write patterns using
the matcher module and rewrite rules using the rules module.
Let's say we want to turn all calls to exp into calls to log. We first
write some convenience functions for constructing patterns and rules.
Exp = lambda x: Primitive(lax.exp_p, (x,), Params())
Log = lambda x: Primitive(lax.log_p, (x,), Params())
We can then write our pattern and accompanying rule.
exp_pattern = Exp(matcher.Var('x'))
exp_to_log_rule = rules.make_rule(exp_pattern, lambda x: Log(x))
We can now rewrite an example expression.
expr = Exp(Literal(5.))
new_expr = exp_to_log_rule(expr)
assert new_expr == Log(Literal(5.))
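The mechanics behind a pattern with a variable and a rule built from it can be sketched without Oryx. `PVar`, `match`, and `make_rule` below are hypothetical toys over tagged tuples, not the `matcher` and `rules` APIs; as an assumption of this sketch, a rule returns its input unchanged when the pattern does not match.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PVar:
    """Hypothetical pattern variable that binds whatever it matches."""
    name: str


def match(pattern, expr, bindings):
    """Returns updated bindings if expr matches pattern, else None."""
    if isinstance(pattern, PVar):
        if pattern.name in bindings:  # a repeated variable must agree
            return bindings if bindings[pattern.name] == expr else None
        return {**bindings, pattern.name: expr}
    if isinstance(pattern, tuple) and isinstance(expr, tuple):
        if len(pattern) != len(expr):
            return None
        for p, e in zip(pattern, expr):
            bindings = match(p, e, bindings)
            if bindings is None:
                return None
        return bindings
    return bindings if pattern == expr else None


def make_rule(pattern, rhs):
    """Rewrites a matching expression via rhs(**bindings); else returns it."""
    def rule(expr):
        bindings = match(pattern, expr, {})
        return expr if bindings is None else rhs(**bindings)
    return rule


Exp = lambda x: ("prim", "exp", (x,))
Log = lambda x: ("prim", "log", (x,))
Lit = lambda v: ("lit", v)

exp_to_log = make_rule(Exp(PVar("x")), lambda x: Log(x))
assert exp_to_log(Exp(Lit(5.0))) == Log(Lit(5.0))
assert exp_to_log(Log(Lit(5.0))) == Log(Lit(5.0))  # no match: unchanged
```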
All of the useful machinery in matcher and rules works with JaxExpressions,
so we can make complex rewrite rules without too much work. For example, we can
use matcher.Segment to match all of the operands of an expression at once.
Add = lambda *args: Primitive(lax.add_p, args, Params())
Sub = lambda *args: Primitive(lax.sub_p, args, Params())
add_pattern = Add(matcher.Segment('args'))
add_to_sub_rule = rules.make_rule(add_pattern, lambda args: Sub(*args))
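A segment differs from a plain variable in that it binds a whole run of operands, so one pattern covers nodes of any arity. The sketch below is a hypothetical toy (not `matcher.Segment`) that finds segment boundaries by backtracking over every possible length, shortest first; unlike the real matcher, it does not check that a repeated segment name binds the same run twice.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Segment:
    """Hypothetical placeholder that matches a run of zero or more operands."""
    name: str


def match_operands(patterns, operands, bindings):
    """Matches a tuple of operands against patterns that may hold Segments."""
    if not patterns:
        return bindings if not operands else None
    head, rest = patterns[0], patterns[1:]
    if isinstance(head, Segment):
        # Try every possible segment length, shortest first (backtracking).
        for k in range(len(operands) + 1):
            bound = {**bindings, head.name: operands[:k]}
            result = match_operands(rest, operands[k:], bound)
            if result is not None:
                return result
        return None
    if operands and head == operands[0]:
        return match_operands(rest, operands[1:], bindings)
    return None


Add = lambda *args: ("prim", "add", args)
Sub = lambda *args: ("prim", "sub", args)


def add_to_sub(expr):
    """Rewrites add(*args) into sub(*args) for any number of operands."""
    if expr[:2] != ("prim", "add"):
        return expr
    bindings = match_operands((Segment("args"),), expr[2], {})
    return expr if bindings is None else Sub(*bindings["args"])


print(add_to_sub(Add("x", "y", "z")))  # ('prim', 'sub', ('x', 'y', 'z'))
```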
Rewriting JAX functions
We provide a JAX function transformation, rewrite, that takes a rewrite rule
(such as a term rewriter built from a set of rules) and transforms an input
function by first converting it into an expression, then applying the rule to
rewrite the expression, and finally evaluating the rewritten expression with
the function's inputs.
Example:
Exp = lambda x: Primitive(lax.exp_p, (x,), Params())
Log = lambda x: Primitive(lax.log_p, (x,), Params())
exp_pattern = Exp(matcher.Var('x'))
exp_to_log_rule = rules.term_rewriter(
    rules.make_rule(exp_pattern, lambda x: Log(x)))
def f(x):
  return jnp.exp(x) + 1.
new_f = rewrite(f, exp_to_log_rule)
new_f(1.)  # ==> 1. (i.e. log(1.) + 1.)
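The three steps of rewrite (trace, rewrite, evaluate) can be sketched end to end in plain Python. All names here are hypothetical: `Tracer` records operations instead of using JAX tracing, `toy_term_rewrite` repeatedly applies a rule bottom-up until the tree stops changing (the exact traversal strategy of rules.term_rewriter is not described here), and `toy_evaluate` uses Python math functions instead of `primitive.bind`.

```python
import math

# Hypothetical toy pipeline: trace -> rewrite -> evaluate.
# Expressions: ("var", name) | ("lit", value) | ("prim", name, args)
IMPLS = {"exp": math.exp, "log": math.log, "add": lambda a, b: a + b}


class Tracer:
    """Records operations as expression trees instead of computing them."""

    def __init__(self, expr):
        self.expr = expr

    def __add__(self, other):
        other = other.expr if isinstance(other, Tracer) else ("lit", other)
        return Tracer(("prim", "add", (self.expr, other)))


def exp(x):
    return Tracer(("prim", "exp", (x.expr,)))


def toy_evaluate(expr, env):
    if expr[0] == "var":
        return env[expr[1]]
    if expr[0] == "lit":
        return expr[1]
    _, name, args = expr
    return IMPLS[name](*(toy_evaluate(a, env) for a in args))


def toy_term_rewrite(rule, expr):
    """Applies rule bottom-up, repeating until the tree stops changing."""
    while True:
        new = expr
        if new[0] == "prim":
            new = ("prim", new[1],
                   tuple(toy_term_rewrite(rule, a) for a in new[2]))
        new = rule(new)
        if new == expr:
            return new
        expr = new


def toy_rewrite(f, rule):
    def wrapped(x):
        expr = f(Tracer(("var", "x"))).expr  # 1. trace f into an expression
        expr = toy_term_rewrite(rule, expr)  # 2. rewrite the expression
        return toy_evaluate(expr, {"x": x})  # 3. evaluate on the real input
    return wrapped


def exp_to_log(expr):
    return ("prim", "log", expr[2]) if expr[:2] == ("prim", "exp") else expr


def f(x):
    return exp(x) + 1.0


new_f = toy_rewrite(f, exp_to_log)
print(new_f(1.0))  # 1.0, i.e. log(1.0) + 1.0
```

Repeating until a fixed point means a rule whose output matches its own pattern would loop forever; the toy relies on exp_to_log producing a log node that it never matches again.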
Modules
matcher module: A basic pattern matching system.
rules module: A simple term-rewriting system.
trace_util module: Module for JAX tracing utility functions.
Classes
class BoundExpression: Represents JAX expressions with closed over constants.
class CallPrimitive: Encapsulates JAX CallPrimitives like jit and pmap.
class JaxExpression: A node in an expression tree.
class JaxVar: An expression that looks up a provided name in the environment.
class Literal: An expression that evaluates to a provided scalar literal value.
class Params: An immutable dictionary used to represent parameters of JAX primitives.
class Part: Used to select the outputs of JAX primitives with multiple outputs.
class PjitPrimitive: Encapsulates JAX's pjit primitive.
class Primitive: A JaxExpression corresponding to a jax.core.Primitive.
Functions
evaluate(...): Evaluates expressions into JAX values.
evaluate_jax_expression(...): Evaluates a JaxExpression in an environment.
evaluate_value(...): Default evaluate function for numerical types.
jaxpr_to_expressions(...): Converts a JAXpr into a tuple of output JaxExpressions.
make_bound_expression(...): Returns a function that traces a function to produce an expression.
primitive_to_expression(...): Converts a JAX primitive into a Primitive expression.
rewrite(...): Rewrites a JAX function according to a rewrite rule.
Type Aliases
Other Members
custom_expressions