Contains Oryx's core transformations and functionality.
Modules
interpreters
module: Contains function transformations implemented using JAX tracing machinery.
ppl
module: Module for probabilistic programming features.
primitive
module: Module for higher order primitives.
pytree
module: Contains the Pytree class.
serialize
module: Contains logic for serializing and deserializing PytreeTypes.
state
module: Module for stateful functions.
trace_util
module: Module for JAX tracing utility functions.
Classes
class FlatPrimitive
: Contains default implementations of transformations.
class HigherOrderPrimitive
: A primitive that appears in traces through transformations.
class NonInvertibleError
: Raised by a custom inverse definition when values are unknown.
class Pytree
: Class that registers objects as JAX pytree nodes.
Functions
call_bind(...)
: Binds a primitive to a function call.
custom_inverse(...)
: Decorates a function to enable defining a custom inverse.
inverse_and_ildj(...)
: Inverse and ILDJ function transformation.
log_prob(...)
: LogProb function transformation.
nest(...)
: Wraps a function to create a new scope for harvested values.
plant(...)
: Transforms a function into one that injects values in place of sown ones.
reap(...)
: Transforms a function into one that returns its sown values.
sow(...)
: Marks a value with a name and a tag.
tie_all(...)
: An identity function that ties arguments together in a JAX trace.
tie_in(...)
: A reimplementation of jax.tie_in that handles pytrees.
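The relationship between sow, reap, and plant described above can be illustrated with a small, dependency-free sketch. This is not Oryx's implementation, which relies on JAX tracing. The names `toy_sow`, `toy_reap`, `toy_plant`, and `_ACTIVE` are hypothetical stand-ins that only mimic the described behavior: a sown value is the identity by default, is collected under reap, and is replaced under plant.

```python
# Hypothetical pure-Python stand-ins for illustration only; Oryx itself
# implements these transformations with JAX tracing machinery.
_ACTIVE = []  # stack of active toy transformation contexts


def toy_sow(value, *, tag, name):
    """Marks `value` with a name and a tag; identity outside any context."""
    for ctx in reversed(_ACTIVE):
        if ctx["tag"] == tag:
            if name in ctx["planted"]:
                # Under toy_plant: inject the planted value instead.
                return ctx["planted"][name]
            # Under toy_reap: record the sown value.
            ctx["reaped"][name] = value
            return value
    return value


def toy_reap(f, *, tag):
    """Returns a function that yields the values `f` sows under `tag`."""
    def wrapped(*args):
        ctx = {"tag": tag, "planted": {}, "reaped": {}}
        _ACTIVE.append(ctx)
        try:
            f(*args)
        finally:
            _ACTIVE.pop()
        return ctx["reaped"]
    return wrapped


def toy_plant(f, *, tag):
    """Returns a function that injects values in place of sown ones."""
    def wrapped(planted, *args):
        ctx = {"tag": tag, "planted": dict(planted), "reaped": {}}
        _ACTIVE.append(ctx)
        try:
            return f(*args)
        finally:
            _ACTIVE.pop()
    return wrapped


def f(x):
    y = toy_sow(x + 1.0, tag="intermediate", name="y")
    return y * 2.0


plain = f(1.0)                                               # sow acts as identity
reaped = toy_reap(f, tag="intermediate")(1.0)                # collects sown values
planted = toy_plant(f, tag="intermediate")({"y": 5.0}, 1.0)  # overrides "y"
```

The sketch uses a global context stack for brevity, so it is neither thread-safe nor compatible with JAX transformations such as jit.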
Other Members | |
---|---|
ildj_registry | Instance of oryx.core.interpreters.inverse.core.InverseDict |
log_prob_registry | |
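To give a rough intuition for how a registry of inverse rules could drive a transformation like inverse_and_ildj(...), the sketch below uses a plain dictionary for scalar functions. It is only an assumption-laden illustration: `toy_ildj_registry`, `toy_inverse_and_ildj`, `double`, and `standard_normal_log_density` are hypothetical names, and the real ildj_registry is an InverseDict operating on JAX primitives, not a Python dict keyed by functions. The last lines apply the standard change-of-variables formula, which is one way to think about what a log-probability transformation computes.

```python
import math


def double(x):
    return 2.0 * x


# Hypothetical registry: maps a forward scalar function to a rule that
# returns (inverse value, inverse log-det Jacobian) for an output y.
toy_ildj_registry = {
    math.exp: lambda y: (math.log(y), -math.log(y)),  # x = log y, dx/dy = 1/y
    double: lambda y: (y / 2.0, -math.log(2.0)),      # x = y / 2, dx/dy = 1/2
}


def toy_inverse_and_ildj(*fs):
    """For y = fs[-1](...fs[0](x)), returns a function y -> (x, ildj)."""
    def wrapped(y):
        total_ildj = 0.0
        # Invert the composition from the outermost function inward,
        # summing the per-step log-det Jacobian terms.
        for f in reversed(fs):
            y, ildj = toy_ildj_registry[f](y)
            total_ildj += ildj
        return y, total_ildj
    return wrapped


def standard_normal_log_density(x):
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


# Forward function: y = 2 * exp(x). Invert it at y = 2.
x, ildj = toy_inverse_and_ildj(math.exp, double)(2.0)

# Change of variables: log p_Y(y) = log p_X(x) + ILDJ.
log_p_y = standard_normal_log_density(x) + ildj
```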