Module for the harvest transformation.
This module contains a general-purpose set of tools for transforming functions with a specific side-effect mechanism into pure functions. The names of the transformations in this module are inspired by the Sow/Reap mechanism in Mathematica.
The harvest module exposes two main functions: sow and harvest. sow is
used to tag values and harvest can inject values into functions or pull out
tagged values.
harvest is a very general-purpose transformation focused purely on taking
functions that have special side effects (defined using sow) and
"functionalizing" them. Specifically, a function
f :: (x: X) -> Y has a set of defined intermediates, or Sows. This set
can be divided into intermediates you are "collecting" and intermediates you are
"injecting", or Reaps and Plants respectively. Functionalizing
f now gives you harvest(f) :: (plants: Plants, x: X) -> Tuple[Y, Reaps].
Generally, most users will not need to use harvest directly, but will use
wrappers around it.
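To make the signature above concrete, here is a hand-written sketch of what functionalizing a function with a single sow amounts to. It does not use the library at all. The aliases `Plants` and `Reaps` and the function `f_functionalized` are hypothetical names that only mirror the type notation used above.

```python
from typing import Any, Dict, Tuple

# Hypothetical aliases mirroring the notation above: both map sow names to values.
Plants = Dict[str, Any]
Reaps = Dict[str, Any]


def f(x: float) -> float:
    # Original function; conceptually it has one intermediate, named 'y'.
    y = x + 1.0
    return y + 1.0


def f_functionalized(plants: Plants, x: float) -> Tuple[float, Reaps]:
    # Hand-written analogue of harvest(f): the intermediate 'y' is either
    # injected from `plants` or computed normally and collected into `reaps`.
    reaps: Reaps = {}
    if "y" in plants:
        y = plants["y"]
    else:
        y = x + 1.0
        reaps["y"] = y
    return y + 1.0, reaps
```

With no plants, `f_functionalized({}, 1.0)` computes the same output as `f(1.0)` and also returns `{'y': 2.0}`; planting `'y'` overrides the intermediate and leaves nothing to reap.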
sow
sow is the function used to tag values in a function. It takes in a single
positional argument, value, which is returned as an output, so sow outside
of a tracing context behaves like the identity function, i.e.
sow(x, ...) == x. It also takes in two mandatory keyword arguments,
tag and name. tag is a string used to namespace intermediate values in a
function. For example, some intermediates may be useful for probabilistic
programming (samples), and others may be useful for logging (summaries). The tag
enables harvest to interact with only one set of intermediates at a time.
The name is a string that describes the value you are sow-ing. Eventually,
when calling harvest on a function, the name is used as the identifier
for the intermediate value.
sow also takes an optional string keyword argument, mode, which defaults
to 'strict'. The mode of a sow describes how it behaves when
the same name appears multiple times. In 'strict' mode, sow raises an error if the
same (tag, name) appears more than once. Another option is 'append', in
which all sows of the same name are appended into a growing array. The third
option is 'clobber', where only the final sown value for a given (tag, name)
is returned. The last optional argument for sow is key, which is
automatically tied in to the output of sow to introduce a fake
data dependence. By default, it is None.
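The three modes differ only in how repeated sows of the same name update the collected values. The sketch below models that bookkeeping with a plain dictionary. The helper `record_sow` is hypothetical and is not part of the module; a Python list stands in for the growing array that 'append' produces, and the error message is illustrative only.

```python
from typing import Any, Dict


def record_sow(reaps: Dict[str, Any], name: str, value: Any, mode: str = "strict") -> Any:
    """Hypothetical helper: update `reaps` the way each sow mode is described above."""
    if mode == "strict":
        # A second sow with the same name is an error.
        if name in reaps:
            raise ValueError(f"duplicate sow for name {name!r} in strict mode")
        reaps[name] = value
    elif mode == "append":
        # Every sow of the same name is kept, in order.
        reaps.setdefault(name, []).append(value)
    elif mode == "clobber":
        # Later sows overwrite earlier ones; only the final value survives.
        reaps[name] = value
    else:
        raise ValueError(f"unknown mode {mode!r}")
    # Like sow itself, the helper returns its input unchanged.
    return value
```

Calling it three times with values 1.0, 2.0 and 3.0 leaves `[1.0, 2.0, 3.0]` under 'append', `3.0` under 'clobber', and raises on the second call under 'strict'.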
sow_cond
sow_cond is a variant of sow that takes an additional positional argument,
pred. It supports a single mode 'cond_clobber', which is like clobber,
but sows values conditionally on pred, falling back on zeros if no sow took
place. This allows reaping values from loop iterations besides the final one.
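One way to picture 'cond_clobber' across loop iterations: the reaped slot starts at zero, and each iteration overwrites it only when its predicate holds. The sketch below models that described behavior in plain Python; `cond_clobber_over_loop` is a hypothetical name, and this is not how the library implements it under tracing.

```python
from typing import List


def cond_clobber_over_loop(values: List[float], preds: List[bool]) -> float:
    """Hypothetical model of 'cond_clobber' reaping across loop iterations.

    Returns the value from the last iteration whose predicate was true,
    or zero if no predicate was ever true.
    """
    reaped = 0.0  # fallback used when no conditional sow fires
    for value, pred in zip(values, preds):
        # In a traced loop body this would be something like
        # jnp.where(pred, value, reaped); plain Python branching is used here.
        reaped = value if pred else reaped
    return reaped
```

A plain 'clobber' sow inside the same loop would always report the final iteration's value, whereas this model can report, for example, the second of three iterations.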
harvest
harvest is a function transformation that augments the behavior of sows
in the function body. Recall that, by default, sows act as identity functions
and do not affect the semantics of a function. Harvesting f produces a
function that can take advantage of sows present in its execution. harvest
takes in a function f and a string tag, and only interacts with sows whose
tag matches the input tag. The returned function can interact with the sows
in the function body in two ways. The first is "injection", where
intermediate values in the function can be overridden: harvest(f) takes in an
additional initial argument, plants, a dictionary mapping names to values.
Each name in plants should correspond to a sow in f, and when harvest(f) runs,
it substitutes the value from the plants dictionary for that sow's runtime
value. The second is "collection": if harvest(f) encounters a sow whose tag
matches and whose name is not in plants, it adds the output of the sow to a
dictionary, called reaps, that maps the sow name to its output. At the end of
harvest(f)'s execution, the reaps dictionary contains the outputs of all sows
whose values were not injected, or "planted."
The general convention is that, after any given execution of
harvest(f, tag=tag), no sows of the given tag remain if the function is
reharvested. That is, if we nest harvests with the same tag, as in
harvest(harvest(f, tag='some_tag'), tag='some_tag'), the outer
harvest has nothing to plant or to reap.
Examples:
Using sow and harvest
```python
from oryx.core.interpreters.harvest import harvest, sow

def f(x):
    y = sow(x + 1., tag='intermediate', name='y')
    return y + 1.

# Injecting, or "planting" a value for `y`.
harvest(f, tag='intermediate')({'y': 0.}, 1.)  # ==> (1., {})
harvest(f, tag='intermediate')({'y': 0.}, 5.)  # ==> (1., {})

# Collecting, or "reaping" the value of `y`.
harvest(f, tag='intermediate')({}, 1.)  # ==> (3., {'y': 2.})
harvest(f, tag='intermediate')({}, 5.)  # ==> (7., {'y': 6.})
```
Using reap and plant
reap and plant are simple wrappers around harvest. reap only pulls
intermediate values without injecting, and plant only injects values without
collecting intermediate values.
```python
from oryx.core.interpreters.harvest import plant, reap, sow

def f(x):
    y = sow(x + 1., tag='intermediate', name='y')
    return y + 1.

# Injecting, or "planting" a value for `y`.
plant(f, tag='intermediate')({'y': 0.}, 1.)  # ==> 1.
plant(f, tag='intermediate')({'y': 0.}, 5.)  # ==> 1.

# Collecting, or "reaping" the value of `y`.
reap(f, tag='intermediate')(1.)  # ==> {'y': 2.}
reap(f, tag='intermediate')(5.)  # ==> {'y': 6.}
```
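Since reap and plant are described as simple wrappers around harvest, one plausible way to see the relationship: reap calls the harvested function with empty plants and keeps only the reaps, while plant keeps only the output. In this sketch, `toy_reap`, `toy_plant`, and `f_harvested` are hypothetical; the wrappers take an already-harvested function so the block stays self-contained, whereas the library's reap and plant take f and a tag directly, as the example above shows.

```python
from typing import Any, Callable, Dict, Tuple

Harvested = Callable[..., Tuple[Any, Dict[str, Any]]]


def f_harvested(plants: Dict[str, Any], x: float) -> Tuple[float, Dict[str, Any]]:
    # Hand-written stand-in for harvest(f, tag='intermediate') from the example.
    reaps: Dict[str, Any] = {}
    if "y" in plants:
        y = plants["y"]
    else:
        y = x + 1.0
        reaps["y"] = y
    return y + 1.0, reaps


def toy_reap(harvested: Harvested) -> Callable[..., Dict[str, Any]]:
    # reap: plant nothing and keep only the collected intermediates.
    def reaped(*args: Any) -> Dict[str, Any]:
        _, reaps = harvested({}, *args)
        return reaps

    return reaped


def toy_plant(harvested: Harvested) -> Callable[..., Any]:
    # plant: inject values and keep only the function's output.
    def planted(plants: Dict[str, Any], *args: Any) -> Any:
        out, _ = harvested(plants, *args)
        return out

    return planted
```

These reproduce the outputs in the example above: `toy_reap(f_harvested)(1.0)` gives `{'y': 2.0}` and `toy_plant(f_harvested)({'y': 0.0}, 5.0)` gives `1.0`.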
Sharp edges
- harvest has undefined semantics under automatic differentiation. If a function you're taking the gradient of has a sow, it might produce unintuitive results when harvested. To better control gradient semantics, you can use jax.custom_jvp or jax.custom_vjp. The current implementation sows primals and tangents in the JVP but ignores cotangents in the VJP. These particular semantics are subject to change.
- Planting values into a pmap only partially works. Harvest tries to plant all the values, assuming they have a leading map dimension.
Classes
class HarvestTrace: An evaluating trace that dispatches to a dynamic context.
class HarvestTracer: A HarvestTracer just encapsulates a single value.
Functions
call_and_reap(...): Transforms a function into one that additionally returns its sown values.
nest(...): Wraps a function to create a new scope for harvested values.
plant(...): Transforms a function into one that injects values in place of sown ones.
reap(...): Transforms a function into one that returns its sown values.
sow(...): Marks a value with a name and a tag.