
Module: oryx.experimental.matching.rules

A simple term-rewriting system.

This module is a Python implementation of the rule combinators from rules, a library by Alexey Radul.

Expressions

Term-rewriting involves searching expression trees for patterns and replacing those patterns with new expressions. To represent expression trees, this module provides a single-dispatch API for registering both builtin Python types, like tuples and dicts, and custom classes as nodes in an expression tree.

The two functions are:

  1. tree_map(expr, fn), which takes in an expression expr and a function fn and returns expr with fn mapped over its direct children.
  2. tree_children(expr), which takes in an expression expr and returns an iterator over its children. It acts similarly to Python's iter, but lets us provide custom iteration for builtin types: for a dict, for example, we want to iterate over its values, not its keys.
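A minimal sketch of how such a single-dispatch API could look, assuming `functools.singledispatch` and registering only tuples and dicts. It illustrates the idea and is not oryx's actual implementation. Leaves such as ints fall through to the default case and have no children.

```python
from functools import singledispatch


@singledispatch
def tree_map(expr, fn):
    """Default case: leaves (ints, strs, ...) have no children to map over."""
    return expr


@tree_map.register(tuple)
def _(expr, fn):
    # Rebuild the tuple with fn applied to each element.
    return tuple(fn(child) for child in expr)


@tree_map.register(dict)
def _(expr, fn):
    # Keep the keys; map fn over the values only.
    return {key: fn(value) for key, value in expr.items()}


@singledispatch
def tree_children(expr):
    """Default case: leaves yield no children."""
    return iter(())


@tree_children.register(tuple)
def _(expr):
    return iter(expr)


@tree_children.register(dict)
def _(expr):
    # Iterate over the values of a dict, not its keys.
    return iter(expr.values())


print(tree_map((1, 2, 3), lambda x: x * 10))        # (10, 20, 30)
print(tree_map({"a": 1, "b": 2}, lambda x: x + 1))  # {'a': 2, 'b': 3}
print(list(tree_children({"a": 1, "b": 2})))        # [1, 2]
print(list(tree_children(5)))                       # []
```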

Both functions are single-dispatch and can be overridden for custom types. Alternatively, this module provides the Expression class, which has the methods Expression.tree_map and Expression.tree_children and is already registered with the two functions. Any custom type that subclasses Expression is therefore automatically registered with tree_map and tree_children.
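To illustrate why subclassing is enough, here is a sketch assuming `functools.singledispatch`: registering a base class once makes every subclass dispatch to that implementation, because `singledispatch` resolves types along the method resolution order. The `Expression` class here is a stand-in for the module's class, and the `Add` node is a made-up example.

```python
from functools import singledispatch


@singledispatch
def tree_map(expr, fn):
    return expr  # Leaves have no children.


@singledispatch
def tree_children(expr):
    return iter(())


class Expression:
    """Hypothetical base class; subclasses implement the two methods."""

    def tree_map(self, fn):
        raise NotImplementedError

    def tree_children(self):
        raise NotImplementedError


# Registered once for the base class; subclasses inherit the dispatch.
@tree_map.register(Expression)
def _(expr, fn):
    return expr.tree_map(fn)


@tree_children.register(Expression)
def _(expr):
    return expr.tree_children()


class Add(Expression):
    """Hypothetical binary node used only for illustration."""

    def __init__(self, left, right):
        self.left, self.right = left, right

    def tree_map(self, fn):
        return Add(fn(self.left), fn(self.right))

    def tree_children(self):
        return iter((self.left, self.right))

    def __eq__(self, other):
        return isinstance(other, Add) and (self.left, self.right) == (other.left, other.right)

    def __repr__(self):
        return f"Add({self.left!r}, {self.right!r})"


print(tree_map(Add(1, 2), lambda x: x * 10))  # Add(10, 20)
print(list(tree_children(Add(1, 2))))         # [1, 2]
```

No explicit registration of `Add` is needed; the dispatch on `Expression` covers it.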

Rules

A rule is a function that maps an input expression to an output expression. Rules can be constructed with make_rule or built from other rules using "rule combinators".

make_rule

A rule can be constructed from a pattern (see matcher) and a handler function. The resulting rule checks whether the input expression matches the pattern and, if so, passes the resulting bindings to handler to produce the output expression. If the input expression does not match the pattern, it is returned as-is.

from oryx.experimental.matching.rules import make_rule
rule = make_rule(1, lambda: 2) # Replaces 1 with 2
rule(1) # ==> 2
rule(3) # ==> 3
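A minimal sketch of the behaviour just described, assuming a hypothetical matcher that only supports literal patterns compared with `==`; the real matcher module supports far richer patterns. In this sketch the rule returns its input unchanged on a failed match and otherwise calls the handler with the bindings as keyword arguments.

```python
def match(pattern, expr):
    """Hypothetical literal-only matcher: returns a bindings dict, or None on failure."""
    return {} if pattern == expr else None


def make_rule(pattern, handler):
    """Sketch of make_rule: rewrite expr with handler if it matches pattern."""
    def rule(expr):
        bindings = match(pattern, expr)
        if bindings is None:
            return expr  # No match: return the input as-is.
        # Literal patterns bind nothing, so the handler takes no arguments.
        return handler(**bindings)
    return rule


rule = make_rule(1, lambda: 2)  # Replaces 1 with 2
print(rule(1))  # 2
print(rule(3))  # 3
```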

We can use matcher combinators to construct more complex rules. For example, we can use a matcher.Var to add one to any input expression.

from oryx.experimental.matching.matcher import Var
add_one = make_rule(Var('x'), lambda x: x + 1)
add_one(1) # ==> 2
add_one(2) # ==> 3

Alternatively, we can use matcher.Choice to add one only if the number is 1 or 2.

from oryx.experimental.matching.matcher import Choice
maybe_add_one = make_rule(Choice(1, 2, name='x'), lambda x: x + 1)
maybe_add_one(1) # ==> 2
maybe_add_one(2) # ==> 3
maybe_add_one(3) # ==> 3
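The sketch below mimics these two matchers with hypothetical `Var` and `Choice` classes: a `Var` binds any expression that passes its restrictions, and a named `Choice` binds the expression if any of its options matches. It deliberately ignores nested patterns and repeated variables, which the real matcher module handles, so treat it as an illustration only.

```python
class Var:
    """Hypothetical pattern variable: binds any expression passing its restrictions."""

    def __init__(self, name, restrictions=()):
        self.name = name
        self.restrictions = list(restrictions)


class Choice:
    """Hypothetical alternation pattern: matches if any option matches."""

    def __init__(self, *options, name=None):
        self.options = options
        self.name = name


def match(pattern, expr):
    """Return a dict of bindings if expr matches pattern, else None."""
    if isinstance(pattern, Var):
        if all(check(expr) for check in pattern.restrictions):
            return {pattern.name: expr}
        return None
    if isinstance(pattern, Choice):
        for option in pattern.options:
            bindings = match(option, expr)
            if bindings is not None:
                # A named Choice also binds the whole matched expression.
                if pattern.name is not None:
                    bindings = {**bindings, pattern.name: expr}
                return bindings
        return None
    return {} if pattern == expr else None  # Literal pattern.


def make_rule(pattern, handler):
    def rule(expr):
        bindings = match(pattern, expr)
        return expr if bindings is None else handler(**bindings)
    return rule


add_one = make_rule(Var('x'), lambda x: x + 1)
maybe_add_one = make_rule(Choice(1, 2, name='x'), lambda x: x + 1)
print(add_one(1), add_one(2))                                # 2 3
print(maybe_add_one(1), maybe_add_one(2), maybe_add_one(3))  # 2 3 3
```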

Rule combinators

Another way to build more complex rules is to use "rule combinators", or higher-order rules: functions that take in one or more rules and return a new rule.

rule_list

rule_list takes in a variable number of rules and tries them in sequence, returning the first result that differs from the input expression. If none of the rules change the expression, it is returned unchanged. A rule whose pattern matches but whose handler leaves the expression unchanged does not count as a success, so rule_list moves on to the next rule. For example, rule_list(make_rule(0, lambda: 0), make_rule(Var('x'), lambda x: 1/x)) still divides by zero when called on 0: the first rule matches 0 but returns it unchanged, so the second rule is tried.

Here are some more examples of using rule_list:

one_to_three = make_rule(1, lambda: 3)
three_to_one = make_rule(3, lambda: 1)
rule = rule_list(one_to_three, three_to_one)

rule(1) # ==> 3
rule(3) # ==> 1
rule(2) # ==> 2
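A minimal sketch of rule_list, assuming "changed" means "not equal to the input" (an assumption; the real implementation may test change differently) and using a hypothetical literal-only make_rule. Because rules are plain functions, the divide-by-zero case can be reproduced with an identity function standing in for make_rule(0, lambda: 0).

```python
def make_rule(pattern, handler):
    """Hypothetical literal-only make_rule: rewrite expr if it equals pattern."""
    def rule(expr):
        return handler() if expr == pattern else expr
    return rule


def rule_list(*rules):
    """Try each rule in order; return the first result that differs from expr."""
    def rule(expr):
        for r in rules:
            new_expr = r(expr)
            if new_expr != expr:  # Change is judged by equality here (an assumption).
                return new_expr
        return expr  # No rule changed the expression.
    return rule


one_to_three = make_rule(1, lambda: 3)
three_to_one = make_rule(3, lambda: 1)
rule = rule_list(one_to_three, three_to_one)
print(rule(1), rule(3), rule(2))  # 3 1 2

# The identity rule leaves 0 unchanged, so rule_list falls through to 1 / x.
divide = rule_list(lambda x: x, lambda x: 1 / x)
try:
    divide(0)
except ZeroDivisionError:
    print("ZeroDivisionError")
```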

in_order

in_order takes in a variable number of rules and applies them in sequence, regardless of whether each one changes the expression, and returns the final expression.

one_to_three = make_rule(1, lambda: 3)
three_to_one = make_rule(3, lambda: 1)
rule = in_order(one_to_three, three_to_one)

rule(1) # ==> 1
rule(3) # ==> 1
rule(2) # ==> 2
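A minimal sketch of in_order, again using a hypothetical literal-only make_rule. Each rule's output is fed to the next one unconditionally, which is why 1 is turned into 3 and then straight back into 1.

```python
def make_rule(pattern, handler):
    """Hypothetical literal-only make_rule: rewrite expr if it equals pattern."""
    def rule(expr):
        return handler() if expr == pattern else expr
    return rule


def in_order(*rules):
    """Apply every rule in sequence, feeding each output into the next rule."""
    def rule(expr):
        for r in rules:
            expr = r(expr)  # Applied whether or not the previous rule changed expr.
        return expr
    return rule


one_to_three = make_rule(1, lambda: 3)
three_to_one = make_rule(3, lambda: 1)
rule = in_order(one_to_three, three_to_one)
print(rule(1), rule(3), rule(2))  # 1 1 2
```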

iterated

iterated takes in a rule and applies it to an expression repeatedly until the expression no longer changes. Note that it is possible to create an infinite loop with iterated if the rule causes the expression to cycle between values.

add_one_until_five = iterated(make_rule(
                                Var('x', restrictions=[lambda x: x < 5]),
                                lambda x: x + 1))

add_one_until_five(1) # ==> 5
add_one_until_five(-10) # ==> 5
add_one_until_five(7) # ==> 7

one_to_three = make_rule(1, lambda: 3)
three_to_one = make_rule(3, lambda: 1)
bad_rule = iterated(rule_list(one_to_three, three_to_one))

bad_rule(1) # ==> Infinite loop!
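A minimal sketch of iterated as a fixed-point loop. The `max_steps` parameter is a hypothetical safety limit added for illustration and is not part of the documented API; the two plain functions are stand-ins for the make_rule and rule_list rules above.

```python
def iterated(rule, max_steps=None):
    """Apply rule repeatedly until the expression stops changing.

    max_steps is a hypothetical guard against cycles, not a documented parameter.
    """
    def iterated_rule(expr):
        steps = 0
        while True:
            new_expr = rule(expr)
            if new_expr == expr:  # Fixed point reached.
                return expr
            expr = new_expr
            steps += 1
            if max_steps is not None and steps >= max_steps:
                raise RuntimeError(f"no fixed point after {max_steps} steps")
    return iterated_rule


# Stand-in for make_rule(Var('x', restrictions=[lambda x: x < 5]), lambda x: x + 1).
def add_one_if_below_five(x):
    return x + 1 if x < 5 else x


add_one_until_five = iterated(add_one_if_below_five)
print(add_one_until_five(1), add_one_until_five(-10), add_one_until_five(7))  # 5 5 7


# Stand-in for rule_list(one_to_three, three_to_one): 1 -> 3 -> 1 -> ... cycles.
def swap_one_and_three(x):
    return {1: 3, 3: 1}.get(x, x)


bad_rule = iterated(swap_one_and_three, max_steps=100)
try:
    bad_rule(1)
except RuntimeError as err:
    print(err)  # no fixed point after 100 steps
```

Detecting a cycle exactly would require remembering every expression seen so far; a step limit is a cheaper, cruder alternative.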

rewrite_subexpressions

rewrite_subexpressions takes in a rule and applies it only to the direct children of the input expression, using tree_map. Types like ints and strs have no children, so rewrite_subexpressions leaves them unchanged; types like dicts and tuples do have children, so rewrite_subexpressions applies the rule to the values they contain. It does not recurse into the children of children.

one_to_three = make_rule(1, lambda: 3)
three_to_one = make_rule(3, lambda: 1)
rule = rewrite_subexpressions(rule_list(one_to_three, three_to_one))

rule(1) # ==> 1
rule((1, 3)) # ==> (3, 1)
rule(((1, 3),)) # ==> ((1, 3),)
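A minimal sketch, assuming a singledispatch tree_map registered for tuples and dicts. `one_three_swap` is a hypothetical plain-function stand-in for rule_list(one_to_three, three_to_one). Because tree_map is called only once, a nested tuple is passed to the rule as a whole and left untouched.

```python
from functools import singledispatch


@singledispatch
def tree_map(expr, fn):
    return expr  # Leaves have no children.


@tree_map.register(tuple)
def _(expr, fn):
    return tuple(fn(child) for child in expr)


@tree_map.register(dict)
def _(expr, fn):
    return {key: fn(value) for key, value in expr.items()}


def rewrite_subexpressions(rule):
    """Apply rule to the direct children of expr only; no recursion."""
    def rewrite(expr):
        return tree_map(expr, rule)
    return rewrite


def one_three_swap(expr):
    # Stand-in for rule_list(one_to_three, three_to_one).
    if expr == 1:
        return 3
    if expr == 3:
        return 1
    return expr


rule = rewrite_subexpressions(one_three_swap)
print(rule(1))                   # 1
print(rule((1, 3)))              # (3, 1)
print(rule(((1, 3),)))           # ((1, 3),)
print(rule({"a": 1, "b": 2}))    # {'a': 3, 'b': 2}
```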

bottom_up

bottom_up recursively applies a rule to an expression, rewriting children first, resulting in a bottom-up rewrite of the expression tree (i.e. one that begins at the leaves and ends with the root). Unlike rewrite_subexpressions, it also rewrites the input expression itself and recurses through all of its descendants.

one_to_three = make_rule(1, lambda: 3)
three_to_one = make_rule(3, lambda: 1)
rule = bottom_up(rule_list(one_to_three, three_to_one))

rule(1) # ==> 3
rule((1, 3)) # ==> (3, 1)
rule(((1, 3),)) # ==> ((3, 1),)

To see how ordering affects the rewrite, consider the following example, where we add one to the elements of a tuple before taking its sum:

Integer = lambda name: Var(name, restrictions=[
    lambda x: isinstance(x, int)])
Tuple = lambda name: Var(name, restrictions=[lambda x: isinstance(x, tuple)])

rule = bottom_up(in_order(
         make_rule(Integer('a'), lambda a: a + 1),
         make_rule(Tuple('t'), lambda t: sum(t))
       ))
rule((1, 2, 3)) # ==> 9 (the leaves become 2, 3, 4 before the tuple is summed)
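A minimal sketch of bottom_up, assuming a singledispatch tree_map registered for tuples and a hypothetical make_rule that supports a single Var-style pattern or a literal. It runs the ordering example on integer inputs; children are rewritten before their parent, so every element is incremented before the tuple is summed.

```python
from functools import singledispatch


@singledispatch
def tree_map(expr, fn):
    return expr  # Leaves have no children.


@tree_map.register(tuple)
def _(expr, fn):
    return tuple(fn(child) for child in expr)


class Var:
    """Hypothetical pattern variable: matches anything passing its restrictions."""

    def __init__(self, name, restrictions=()):
        self.name = name
        self.restrictions = list(restrictions)


def make_rule(pattern, handler):
    """Hypothetical make_rule supporting a single Var or a literal pattern."""
    def rule(expr):
        if isinstance(pattern, Var):
            if all(check(expr) for check in pattern.restrictions):
                return handler(**{pattern.name: expr})
            return expr
        return handler() if expr == pattern else expr
    return rule


def in_order(*rules):
    def rule(expr):
        for r in rules:
            expr = r(expr)
        return expr
    return rule


def bottom_up(rule):
    """Rewrite all children (recursively) first, then the node itself."""
    def rewrite(expr):
        return rule(tree_map(expr, rewrite))
    return rewrite


Integer = lambda name: Var(name, restrictions=[lambda x: isinstance(x, int)])
Tuple = lambda name: Var(name, restrictions=[lambda x: isinstance(x, tuple)])

rule = bottom_up(in_order(
    make_rule(Integer('a'), lambda a: a + 1),
    make_rule(Tuple('t'), lambda t: sum(t)),
))
print(rule((1, 2, 3)))  # 9
```

Because the Integer restriction checks `isinstance(x, int)`, float leaves such as `1.` would not match it, and `(1., 2., 3.)` would simply be summed to `6.0`.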

top_down

top_down recursively applies a rule to an expression, rewriting the root expression first and then recursing down into the children, resulting in a top-down rewrite of the expression tree (i.e. one that begins at the root and ends with the leaves). Unlike rewrite_subexpressions, it also rewrites the input expression itself and recurses through all of its descendants.

one_to_three = make_rule(1, lambda: 3)
three_to_one = make_rule(3, lambda: 1)
rule = top_down(rule_list(one_to_three, three_to_one))

rule(1) # ==> 3
rule((1, 3)) # ==> (3, 1)
rule(((1, 3),)) # ==> ((3, 1),)

To see how ordering affects the rewrite, here is the same rule applied top down. The root tuple is summed before its elements are visited, so the add-one rule never reaches them:

Integer = lambda name: Var(name, restrictions=[
    lambda x: isinstance(x, int)])
Tuple = lambda name: Var(name, restrictions=[lambda x: isinstance(x, tuple)])

rule = top_down(in_order(
         make_rule(Integer('a'), lambda a: a + 1),
         make_rule(Tuple('t'), lambda t: sum(t))
       ))
rule((1, 2, 3)) # ==> 6
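A minimal sketch of top_down under the same assumptions as a bottom-up version would use (a singledispatch tree_map for tuples and a hypothetical Var-only make_rule). The rule is applied to a node first and the recursion then descends into the children of the *result*; with integer inputs, the root tuple is summed to 6 before any element is incremented, and 6 has no children left to visit.

```python
from functools import singledispatch


@singledispatch
def tree_map(expr, fn):
    return expr  # Leaves have no children.


@tree_map.register(tuple)
def _(expr, fn):
    return tuple(fn(child) for child in expr)


class Var:
    """Hypothetical pattern variable: matches anything passing its restrictions."""

    def __init__(self, name, restrictions=()):
        self.name = name
        self.restrictions = list(restrictions)


def make_rule(pattern, handler):
    """Hypothetical make_rule supporting a single Var or a literal pattern."""
    def rule(expr):
        if isinstance(pattern, Var):
            if all(check(expr) for check in pattern.restrictions):
                return handler(**{pattern.name: expr})
            return expr
        return handler() if expr == pattern else expr
    return rule


def in_order(*rules):
    def rule(expr):
        for r in rules:
            expr = r(expr)
        return expr
    return rule


def top_down(rule):
    """Rewrite the node first, then recurse into the children of the result."""
    def rewrite(expr):
        return tree_map(rule(expr), rewrite)
    return rewrite


Integer = lambda name: Var(name, restrictions=[lambda x: isinstance(x, int)])
Tuple = lambda name: Var(name, restrictions=[lambda x: isinstance(x, tuple)])

rule = top_down(in_order(
    make_rule(Integer('a'), lambda a: a + 1),
    make_rule(Tuple('t'), lambda t: sum(t)),
))
print(rule((1, 2, 3)))  # 6
```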

term_rewriter

The term_rewriter combinator composes bottom_up, iterated and rule_list: it takes in a series of rules and repeatedly applies them to every part of the input expression, bottom up, until the expression no longer changes.

It is provided as a convenience for term-rewriting applications, like algebraic simplification, that apply a series of general rewrite rules to all parts of an expression.

Positive = lambda name: Var(name, restrictions=[
    lambda x: isinstance(x, int), lambda x: x > 0])
rule = term_rewriter(make_rule(Positive('a'), lambda a: a - 1))

rule(1) # ==> 0
rule(5) # ==> 0
rule((2, 3)) # ==> (0, 0)
rule(((1, 2), 3)) # ==> ((0, 0), 0)
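A minimal sketch of term_rewriter. The source says only that it composes bottom_up, iterated and rule_list; this sketch assumes the nesting `iterated(bottom_up(rule_list(*rules)))`, which reproduces the examples above but may differ from oryx's exact code. All helpers, including the Var-only make_rule, are hypothetical stand-ins.

```python
from functools import singledispatch


@singledispatch
def tree_map(expr, fn):
    return expr  # Leaves have no children.


@tree_map.register(tuple)
def _(expr, fn):
    return tuple(fn(child) for child in expr)


class Var:
    """Hypothetical pattern variable: matches anything passing its restrictions."""

    def __init__(self, name, restrictions=()):
        self.name = name
        self.restrictions = list(restrictions)


def make_rule(pattern, handler):
    """Hypothetical make_rule supporting a single Var pattern."""
    def rule(expr):
        # all() short-circuits, so later restrictions only see values passing earlier ones.
        if all(check(expr) for check in pattern.restrictions):
            return handler(**{pattern.name: expr})
        return expr
    return rule


def rule_list(*rules):
    def rule(expr):
        for r in rules:
            new_expr = r(expr)
            if new_expr != expr:
                return new_expr
        return expr
    return rule


def iterated(rule):
    def iterated_rule(expr):
        while True:
            new_expr = rule(expr)
            if new_expr == expr:
                return expr
            expr = new_expr
    return iterated_rule


def bottom_up(rule):
    def rewrite(expr):
        return rule(tree_map(expr, rewrite))
    return rewrite


def term_rewriter(*rules):
    """Sweep bottom up with the first applicable rule, repeating until a fixed point."""
    return iterated(bottom_up(rule_list(*rules)))


Positive = lambda name: Var(name, restrictions=[
    lambda x: isinstance(x, int), lambda x: x > 0])
rule = term_rewriter(make_rule(Positive('a'), lambda a: a - 1))
print(rule(1), rule(5), rule((2, 3)), rule(((1, 2), 3)))  # 0 0 (0, 0) ((0, 0), 0)
```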

Classes

class Expression: A class that is auto-registered with tree_map and tree_children.

Functions

bottom_up(...): Returns a rule that recursively rewrites expressions bottom up.

in_order(...): Returns a rule that applies a series of rules in order.

iterated(...): Returns a rule that iteratively applies a rule until convergence.

make_rule(...): Constructs a rewrite rule from a pattern and handler function.

rewrite_subexpressions(...): Returns a rule that rewrites the direct subexpressions of an expression.

rule_list(...): Returns a rule that tries applying a series of rules until one succeeds.

term_rewriter(...): Returns a rule that rewrites expressions iteratively from the bottom up.

top_down(...): Returns a rule that recursively rewrites expressions top down.

tree_children(...): Returns the children of an expression.

tree_map(...): Shallow maps a function over the children of an expression.