Module for the propagate custom Jaxpr interpreter.
The propagate Jaxpr interpreter converts a Jaxpr to a directed graph where vars are nodes and primitives are edges. It initializes invars and outvars with Cells (an interface defined below). A Cell encapsulates a value (or a set of values) that a node in the graph can take on, and each Cell is computed from neighboring Cells using a set of propagation rules, one per primitive. Each rule indicates whether propagation has been completed for the given edge. If so, the propagate interpreter continues on to that primitive's neighbors in the graph. Propagation continues until there are Cells for every node, or until no further progress can be made. Finally, the Cell values for all nodes in the graph are returned.
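Because a Cell may hold a set of values rather than a single one, one way to picture it is as an element of a lattice ordered by information: a "bottom" cell knows nothing about a node, a "top" cell knows its value exactly, and a join combines two cells into one at least as informative as either. The sketch below assumes that model. The class names `LatticeCell` and `Interval` and the method names `top`, `bottom` and `join` are illustrative assumptions, not a transcription of this module's `Cell` class.

```python
import abc
import math


class LatticeCell(abc.ABC):
    """Hypothetical lattice-style cell interface (names are assumptions)."""

    @abc.abstractmethod
    def top(self) -> bool:
        """True when the cell holds complete information."""

    @abc.abstractmethod
    def bottom(self) -> bool:
        """True when the cell holds no information."""

    @abc.abstractmethod
    def join(self, other: "LatticeCell") -> "LatticeCell":
        """Combine two cells into one at least as informative as both."""


class Interval(LatticeCell):
    """The set of scalars with lo <= x <= hi; top once the bounds coincide."""

    def __init__(self, lo: float = -math.inf, hi: float = math.inf):
        self.lo, self.hi = lo, hi

    def top(self) -> bool:
        return self.lo == self.hi

    def bottom(self) -> bool:
        return self.lo == -math.inf and self.hi == math.inf

    def join(self, other: "Interval") -> "Interval":
        # Intersecting the bounds keeps everything both cells know.
        # (An empty intersection would signal a contradiction; not handled.)
        return Interval(max(self.lo, other.lo), min(self.hi, other.hi))
```

Here `Interval(0, 10).join(Interval(5, 20))` narrows to the bounds 5 and 10, and joining `Interval(0, 4)` with `Interval(4, 9)` yields a top cell that pins the value to 4.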
Classes
class Cell
: Base interface for objects used during propagation.
class Environment
: Keeps track of variables and their values during propagation.
class Equation
: Hashable wrapper for jax_core.Jaxprs.
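A Jaxpr equation carries its parameters in a `params` dict, and Python dicts are unhashable, so a wrapper is one way to let equations serve as set members or dictionary keys, for example to record which equations have already finished propagating. The sketch below uses a frozen dataclass. The field layout, the use of plain strings for variables, and the `create` helper are hypothetical and are not meant to mirror the fields of this module's `Equation`.

```python
import dataclasses
from typing import Any, Tuple


@dataclasses.dataclass(frozen=True)
class Equation:
    """Hypothetical hashable stand-in for one Jaxpr equation."""

    invars: Tuple[str, ...]
    outvars: Tuple[str, ...]
    primitive: str
    params: Tuple[Tuple[str, Any], ...] = ()

    @classmethod
    def create(cls, invars, outvars, primitive, **params):
        # Dicts are unhashable, so params become a sorted tuple of items;
        # each param value must itself be hashable.
        return cls(tuple(invars), tuple(outvars), primitive,
                   tuple(sorted(params.items())))
```

With `frozen=True`, the dataclass generates `__hash__` from the fields, so two wrappers built from the same equation compare equal and collapse to one entry in a set.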
Functions
propagate(...)
: Propagates cells in a Jaxpr using a set of rules.
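The propagation loop described at the top of this page can be sketched in plain Python under toy assumptions: a cell is either a known number or `None`, an equation is a `(primitive, invars, outvars)` tuple of variable names, and `add_rule`, `mul_rule` and `RULES` are hypothetical rules written only for this illustration. This is not the signature or implementation of `propagate`. Each rule returns updated input and output cells plus a "done" flag; whenever a node gains a value, the equations touching it are queued again, and the loop ends when the worklist empties.

```python
from collections import deque
from typing import Dict, List, Optional, Tuple

# Hypothetical toy model: a cell is a known number, or None when unknown.
Value = Optional[float]
# An equation (graph edge) is (primitive name, input vars, output vars).
Eqn = Tuple[str, Tuple[str, ...], Tuple[str, ...]]


def add_rule(ins: List[Value], outs: List[Value]):
    """Rule for c = a + b: infers any one cell from the other two."""
    (a, b), (c,) = ins, outs
    if a is not None and b is not None:
        c = a + b
    elif a is not None and c is not None:
        b = c - a
    elif b is not None and c is not None:
        a = c - b
    done = None not in (a, b, c)  # every cell on this edge is known
    return [a, b], [c], done


def mul_rule(ins: List[Value], outs: List[Value]):
    """Rule for c = a * b (forward direction only, for brevity)."""
    (a, b), (c,) = ins, outs
    if a is not None and b is not None:
        c = a * b
    return [a, b], [c], None not in (a, b, c)


RULES = {"add": add_rule, "mul": mul_rule}


def propagate(eqns: List[Eqn], cells: Dict[str, Value]) -> Dict[str, Value]:
    cells = dict(cells)
    # Index each variable (node) by the equations (edges) that touch it.
    neighbors: Dict[str, List[int]] = {}
    for i, (_, invars, outvars) in enumerate(eqns):
        for var in invars + outvars:
            neighbors.setdefault(var, []).append(i)
    worklist = deque(range(len(eqns)))
    queued, finished = set(worklist), set()
    while worklist:  # empties once no rule can make further progress
        i = worklist.popleft()
        queued.discard(i)
        prim, invars, outvars = eqns[i]
        ins, outs, done = RULES[prim](
            [cells.get(v) for v in invars], [cells.get(v) for v in outvars])
        if done:
            finished.add(i)  # a completed edge is never revisited
        for var, val in zip(invars + outvars, ins + outs):
            if val is not None and cells.get(var) is None:
                cells[var] = val
                # A newly known node may unblock its neighboring edges.
                for j in neighbors[var]:
                    if j not in finished and j not in queued:
                        worklist.append(j)
                        queued.add(j)
    return cells
```

For example, with `eqns = [("mul", ("x", "y"), ("t",)), ("add", ("t", "z"), ("out",))]` and known cells `x=2`, `y=3`, `out=10`, the loop fills in `t=6` forwards through the multiplication and then `z=4` backwards through the addition. Termination follows because a node only ever moves from unknown to known, so equations can be re-queued only finitely often.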