Module for JAX tracing utility functions.
Functions
get_dynamic_context(...)
: Returns the current active dynamic context for a trace.
get_shaped_aval(...)
: Converts a JAX value type into a shaped abstract value.
new_dynamic_context(...)
: Creates a dynamic context for a trace.
pv_like(...)
: Converts a JAX value type into a JAX PartialVal.
stage(...)
: Returns a function that stages a function to a ClosedJaxpr.
trees(...)
: Returns a function that determines input and output pytrees from inputs.
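The staging helpers above build on JAX internals, and their exact signatures and return values are not listed on this page. The sketch below only illustrates the same ideas with public JAX APIs, so it is an analogue and not this module's implementation. `jax.ShapeDtypeStruct` stands in for a shaped abstract value, `jax.make_jaxpr` traces a function to a `ClosedJaxpr`, and `jax.tree_util` recovers the input and output pytree structures. The helper names `shaped_aval_like`, `stage_like`, and `trees_like` are hypothetical. The assumption that a staging helper returns the `ClosedJaxpr` together with an `(in_tree, out_tree)` pair is also an illustrative choice.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def shaped_aval_like(x):
  # Hypothetical analogue of get_shaped_aval: keep only shape and dtype,
  # discarding the concrete values.
  x = jnp.asarray(x)
  return jax.ShapeDtypeStruct(x.shape, x.dtype)


def stage_like(f):
  # Hypothetical analogue of stage: trace f once to a ClosedJaxpr and
  # also report the pytree structures of its inputs and outputs.
  def wrapped(*args):
    # return_shape=True also yields a pytree of ShapeDtypeStructs
    # describing f's outputs, so f is traced only once.
    closed_jaxpr, out_shape = jax.make_jaxpr(f, return_shape=True)(*args)
    in_tree = tree_util.tree_structure(args)
    out_tree = tree_util.tree_structure(out_shape)
    return closed_jaxpr, (in_tree, out_tree)
  return wrapped


def trees_like(f):
  # Hypothetical analogue of trees: keep only the (in_tree, out_tree) pair.
  def wrapped(*args):
    return stage_like(f)(*args)[1]
  return wrapped


def f(params, x):
  # Example function with pytree (dict) inputs and outputs.
  return {'y': params['w'] * x + params['b'], 'x': x}


params = {'w': jnp.ones(3), 'b': jnp.zeros(3)}
x = jnp.arange(3.0)
closed_jaxpr, (in_tree, out_tree) = stage_like(f)(params, x)
print(closed_jaxpr)
print(in_tree, out_tree)
print(shaped_aval_like(x))
```

Here the input tree has three leaves (`b`, `w`, and `x`), and the output tree has two (`x` and `y`). The abstract value for `x` records only its shape `(3,)` and its dtype.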