View source on GitHub
Module for JAX tracing utility functions.
Functions
get_dynamic_context(...): Returns the current active dynamic context for a trace.
get_shaped_aval(...): Converts a JAX value type into a shaped abstract value.
new_dynamic_context(...): Creates a dynamic context for a trace.
pv_like(...): Converts a JAX value type into a JAX PartialVal.
stage(...): Returns a function that stages a function to a ClosedJaxpr.
trees(...): Returns a function that determines input and output pytrees from inputs.
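
The idea of staging a function to a ClosedJaxpr and recovering its input and output pytrees can be sketched with public JAX APIs only. This is a minimal illustration, not this module's implementation: it uses `jax.make_jaxpr`, `jax.eval_shape`, and `jax.tree_util.tree_flatten` rather than `stage` or `trees`, whose exact signatures are not given here. The function `f` and its arguments are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def f(x, y):
    # Hypothetical example function; not part of this module.
    return {"sum": x + y, "prod": x * y}


args = (jnp.ones(3), jnp.ones(3))

# Stage `f` to a ClosedJaxpr with the public `jax.make_jaxpr` API
# (an assumption standing in for this module's `stage`).
closed_jaxpr = jax.make_jaxpr(f)(*args)

# Input pytree structure, taken from the positional arguments.
_, in_tree = tree_util.tree_flatten(args)

# Output pytree structure, taken from abstract (shape/dtype-only) outputs.
out_shape = jax.eval_shape(f, *args)
_, out_tree = tree_util.tree_flatten(out_shape)

print(closed_jaxpr)
print(in_tree)
print(out_tree)
```

Here `jax.eval_shape` traces `f` abstractly, so the output structure is obtained without running the computation on concrete values.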