

Returns a function transformation that applies a provided set of handlers.

handlers: A dict that maps JAX primitives to callback functions. Each callback takes (state, *args) and returns (output, new_state). When the transformed function runs, every primitive that appears in handlers is executed by its callback instead of by its default execution rule.
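As a minimal illustration of the callback contract described above, the following sketch defines two handlers keyed by JAX primitive objects (`lax.sin_p` and `lax.exp_p` from `jax.lax`). The handler names `log_sin` and `clipped_exp` are hypothetical. The first observes its input and still computes the default result. The second overrides the primitive with a clipped version. Both are called directly here, outside any transformation, only to show the `(state, *args) -> (output, new_state)` shape.

```python
import jax.numpy as jnp
from jax import lax


# Hypothetical handler that observes its input: it appends `x` to a tuple
# state and still computes `sin` as usual.
def log_sin(state, x):
    return lax.sin(x), state + (x,)


# Hypothetical handler that overrides the default rule: it clips the input
# before exponentiating and leaves the state unchanged.
def clipped_exp(state, x):
    return jnp.exp(jnp.minimum(x, 10.0)), state


# Keys are JAX primitive objects; values follow (state, *args) -> (output, new_state).
handlers = {lax.sin_p: log_sin, lax.exp_p: clipped_exp}

# Calling the handlers directly shows the contract.
y, seen = handlers[lax.sin_p]((), jnp.float32(0.0))
z, same = handlers[lax.exp_p]("unchanged", jnp.float32(20.0))
```

The state can be any Python value the handlers agree on; a tuple log and an untouched string are used here purely for illustration.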

A function transformation that applies the handlers to an input function. Given an input function, the transformation returns a transformed function that takes an initial state as an additional argument and returns the resulting state as an additional output.
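The following is a sketch of how such a transformation could work, assuming a simple interpreter over the input function's jaxpr. It is not this library's implementation. The name `make_handler_transform` is hypothetical, and passing the initial state as the first argument and returning `(output, state)` is an assumed convention. The sketch traces the function with `jax.make_jaxpr`, then evaluates its equations one by one, routing each primitive found in `handlers` to its callback and threading the state through.

```python
import jax
import jax.numpy as jnp
from jax import lax

try:  # newer JAX releases expose jaxpr types under jax.extend.core
    from jax.extend.core import Literal
except ImportError:  # older releases
    from jax.core import Literal


def make_handler_transform(handlers):
    """Hypothetical sketch: returns a transformation that applies `handlers`.

    `handlers` maps a JAX primitive to a callback
    `(state, *args) -> (output, new_state)`.
    """

    def transform(f):
        def wrapped(init_state, *args):
            # Trace `f` to a jaxpr; `out_shape` records its output pytree structure.
            closed, out_shape = jax.make_jaxpr(f, return_shape=True)(*args)
            jaxpr = closed.jaxpr
            env = {}

            def read(var):
                # Literals carry their constant value; other vars live in `env`.
                return var.val if isinstance(var, Literal) else env[var]

            for var, val in zip(jaxpr.constvars, closed.consts):
                env[var] = val
            for var, val in zip(jaxpr.invars, jax.tree_util.tree_leaves(args)):
                env[var] = val

            state = init_state
            for eqn in jaxpr.eqns:
                invals = [read(v) for v in eqn.invars]
                if eqn.primitive in handlers:
                    # Delegate to the handler and thread the state through.
                    # (For multiple-result primitives the handler must return a list.)
                    out, state = handlers[eqn.primitive](state, *invals)
                else:
                    # Default execution rule for unhandled primitives.
                    out = eqn.primitive.bind(*invals, **eqn.params)
                outs = out if eqn.primitive.multiple_results else [out]
                for var, val in zip(eqn.outvars, outs):
                    env[var] = val

            results = [read(v) for v in jaxpr.outvars]
            treedef = jax.tree_util.tree_structure(out_shape)
            return jax.tree_util.tree_unflatten(treedef, results), state

        return wrapped

    return transform


# Usage: a handler that counts `sin` calls while computing them normally.
def count_sin(state, x):
    return lax.sin(x), state + 1


def f(x):
    return lax.sin(lax.sin(x)) + x


g = make_handler_transform({lax.sin_p: count_sin})(f)
out, n = g(0, jnp.float32(0.5))  # n counts the handled `sin` calls
```

Here the handler counts `sin` calls, so the returned state is 2 while the output matches the untransformed function. This sketch only inspects top-level equations: primitives nested inside higher-order primitives such as `jit`, `cond`, or `scan` run with their default rules rather than being intercepted.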