Module: oryx.core.interpreters.propagate

Module for the propagate custom Jaxpr interpreter.

The propagate Jaxpr interpreter converts a Jaxpr to a directed graph where vars are nodes and primitives are edges. It initializes the invars and outvars with Cells (an interface defined below). A Cell encapsulates a value (or a set of values) that a node in the graph can take on, and each Cell is computed from its neighboring Cells using a set of propagation rules, one per primitive. Each rule indicates whether propagation has been completed for the given edge; if so, the interpreter continues on to that primitive's neighbors in the graph. Propagation continues until every node has a Cell or no further progress can be made. Finally, the Cell values for all nodes in the graph are returned.
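To make the algorithm concrete, here is a minimal pure-Python sketch of the same worklist idea. It does not use JAX or oryx: `ValueCell`, `Eqn`, `add_rule`, and `toy_propagate` are hypothetical stand-ins, with string variable names as nodes and a single `add` primitive as the only edge type. Each rule returns updated input and output cells plus a flag marking the edge as finished, and any change to a node re-queues the unfinished edges that touch it.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple


@dataclass(frozen=True)
class ValueCell:
    """Hypothetical toy cell: a known int, or None while still unknown."""
    value: Optional[int] = None

    def top(self) -> bool:
        return self.value is not None

    def join(self, other: "ValueCell") -> "ValueCell":
        # Keep the first known value; otherwise take whatever `other` holds.
        return self if self.top() else other


@dataclass(frozen=True)
class Eqn:
    """Hypothetical graph edge: a primitive name linking input and output vars."""
    prim: str
    invars: Tuple[str, ...]
    outvars: Tuple[str, ...]


# A rule sees one equation's current in/out cells and returns updated
# in/out cells plus a flag saying whether this edge is finished.
Rule = Callable[[List[ValueCell], List[ValueCell]],
                Tuple[List[ValueCell], List[ValueCell], bool]]


def add_rule(incells, outcells):
    """Rule for z = x + y: solves for whichever single value is missing."""
    (x, y), (z,) = incells, outcells
    if x.top() and y.top():
        return incells, [ValueCell(x.value + y.value)], True
    if z.top() and x.top():
        return [x, ValueCell(z.value - x.value)], outcells, True
    if z.top() and y.top():
        return [ValueCell(z.value - y.value), y], outcells, True
    return incells, outcells, False  # not enough information yet


def toy_propagate(eqns: List[Eqn], rules: Dict[str, Rule],
                  env: Dict[str, ValueCell]) -> Dict[str, ValueCell]:
    env = dict(env)
    done = set()
    worklist = list(range(len(eqns)))
    while worklist:
        i = worklist.pop()
        if i in done:
            continue
        eqn = eqns[i]
        vars_ = eqn.invars + eqn.outvars
        old = [env.get(v, ValueCell()) for v in vars_]
        n_in = len(eqn.invars)
        new_in, new_out, finished = rules[eqn.prim](old[:n_in], old[n_in:])
        if finished:
            done.add(i)
        for v, before, after in zip(vars_, old, new_in + new_out):
            joined = before.join(after)
            if joined != before:
                env[v] = joined
                # A node changed: revisit every unfinished edge touching it.
                worklist.extend(j for j, e in enumerate(eqns)
                                if j not in done and v in e.invars + e.outvars)
    # Loop ends when no queued edge can make further progress.
    return env


eqns = [Eqn("add", ("a", "b"), ("c",)), Eqn("add", ("c", "d"), ("e",))]
out = toy_propagate(eqns, {"add": add_rule},
                    {"a": ValueCell(1), "d": ValueCell(4), "e": ValueCell(10)})
print(out["c"], out["b"])  # ValueCell(value=6) ValueCell(value=5)
```

Because the toy `add` rule can solve for any one missing operand, information flows backward here: knowing `a`, `d`, and `e` is enough to recover `c` and then `b`. If too little is known (say only `a`), the loop stops once the worklist empties and the unknown variables are simply absent from the result.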

Classes

class Cell: Base interface for objects used during propagation.
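Since a Cell may represent a set of values rather than a single one, a small illustration helps. The sketch below is a hypothetical `SetCell`, not oryx's `Cell` class: the method names `bottom`, `top`, and `join` are chosen here to echo lattice terminology and are assumptions, not a transcription of the real interface. Joining two cells intersects their candidate sets, so more information means fewer possibilities.

```python
from dataclasses import dataclass
from typing import FrozenSet, Optional


@dataclass(frozen=True)
class SetCell:
    """Hypothetical cell holding the set of values a node may still take.

    `options is None` means nothing is known yet (bottom); a singleton
    set means the value is fully determined (top). An empty set would
    signal contradictory information.
    """
    options: Optional[FrozenSet[int]] = None

    def bottom(self) -> bool:
        return self.options is None

    def top(self) -> bool:
        return self.options is not None and len(self.options) == 1

    def join(self, other: "SetCell") -> "SetCell":
        # Combining two pieces of information narrows the candidate set.
        if self.bottom():
            return other
        if other.bottom():
            return self
        return SetCell(self.options & other.options)


unknown = SetCell()
even = SetCell(frozenset({0, 2, 4}))
small = SetCell(frozenset({2, 3}))
print(unknown.join(even) == even)  # True
print(even.join(small))            # SetCell(options=frozenset({2}))
print(even.join(small).top())      # True
```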

class Environment: Keeps track of variables and their values during propagation.
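One plausible shape for such an environment is sketched below. `ToyEnvironment`, its `read`/`write` methods, and the `unknown`/`join` callables are hypothetical names for illustration, not oryx's `Environment` API. The key design choice is that `write` merges new information with what is already stored and reports whether anything changed, which is what lets a propagation loop detect that no further progress can be made.

```python
from typing import Callable, Dict, Generic, Hashable, TypeVar

C = TypeVar("C")


class ToyEnvironment(Generic[C]):
    """Hypothetical environment mapping variables to cells during propagation."""

    def __init__(self, unknown: Callable[[], C], join: Callable[[C, C], C]):
        self._cells: Dict[Hashable, C] = {}
        self._unknown = unknown  # builds the "nothing known yet" cell
        self._join = join        # merges old and new information

    def read(self, var: Hashable) -> C:
        # Variables that were never written read back as unknown.
        return self._cells.get(var, self._unknown())

    def write(self, var: Hashable, cell: C) -> bool:
        """Merge `cell` into `var`; return True if the stored cell changed."""
        old = self.read(var)
        new = self._join(old, cell)
        if new == old:
            return False
        self._cells[var] = new
        return True

    def __contains__(self, var: Hashable) -> bool:
        return var in self._cells


def join_known(old, new):
    # Toy join on Optional values: the first known value wins.
    return old if old is not None else new


env = ToyEnvironment(unknown=lambda: None, join=join_known)
print(env.read("x"))      # None
print(env.write("x", 3))  # True
print(env.write("x", 7))  # False (the first known value wins)
print(env.read("x"))      # 3
```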

class Equation: Hashable wrapper for jax_core.Jaxprs.
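A hashable equation is useful when equations must serve as dict keys or set members, for example to record which edges have already finished propagating. Equation params are typically held in a dict, which Python cannot hash, so a wrapper has to freeze them. The `ToyEquation` class below is a hypothetical sketch of that idea using string variable names; it is not oryx's `Equation`, and it assumes every param value is itself hashable.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Tuple


@dataclass(frozen=True)  # frozen + eq gives a field-based __hash__
class ToyEquation:
    """Hypothetical hashable stand-in for one Jaxpr equation."""
    primitive: str
    invars: Tuple[str, ...]
    outvars: Tuple[str, ...]
    params: Tuple[Tuple[str, Any], ...]  # sorted (name, value) pairs

    @classmethod
    def make(cls, primitive: str, invars: Iterable[str],
             outvars: Iterable[str], params: Dict[str, Any]) -> "ToyEquation":
        # Lists and dicts are unhashable, so freeze them into tuples.
        # Assumes each param value is already hashable (e.g. ints, tuples).
        return cls(primitive, tuple(invars), tuple(outvars),
                   tuple(sorted(params.items())))


eq1 = ToyEquation.make("reduce_sum", ["x"], ["y"], {"axes": (0,)})
eq2 = ToyEquation.make("reduce_sum", ["x"], ["y"], {"axes": (0,)})
finished = {eq1}         # usable as a set member or dict key
print(eq2 in finished)   # True: equal fields give equal hashes
```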

Functions

propagate(...): Propagates cells in a Jaxpr using a set of rules.