A reimplementation of a subset of the Optax library using Oryx's state system.
This module is an advanced example of how to write stateful code using Oryx. For a more complete and supported optimizers package that includes additional transformations and other features, please take a look at Optax.
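For contrast, here is a minimal sketch of the explicit-state style that Optax uses. Each transformation is an `(init, update)` pair, and the caller threads the state through every step by hand. That state threading is the bookkeeping this module instead delegates to Oryx's state system. The code below is a hypothetical NumPy illustration of a momentum-style `trace`. It is not this module's API or Oryx code, and the `decay` parameter and list-of-arrays update format are assumptions.

```python
import numpy as np

# Hypothetical explicit-state transformation in the style of Optax's
# (init, update) pairs. This is NOT Oryx's API: in this module, state is
# managed by Oryx's state system rather than being passed around by hand.

def trace(decay):
    """Running trace of updates: t <- decay * t + g, emitted as the update."""

    def init(params):
        # One zero-initialized trace per parameter array.
        return [np.zeros_like(p) for p in params]

    def update(updates, state):
        new_trace = [decay * t + g for t, g in zip(state, updates)]
        # The trace is both the outgoing update and the new state.
        return new_trace, new_trace

    return init, update


params = [np.array([1.0, -2.0])]
init, update = trace(decay=0.5)
state = init(params)
for _ in range(2):
    grads = [np.array([1.0, 1.0])]
    updates, state = update(grads, state)
print(updates[0])  # [1.5 1.5]
```

Every call to `update` must receive the state returned by the previous call. Forgetting to do so silently resets the trace, which is the kind of plumbing a state system is meant to hide.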
Functions
add_noise(...)
: Returns a function that adds noise to updates.
apply_every(...)
: Returns a function that accumulates updates and applies them all at once.
chain(...)
: Composes update functions together serially.
clip_by_global_norm(...)
: Returns a function that clips updates to a provided max norm.
optimize(...)
: Runs several iterations of optimization and returns the result.
scale_by_adam(...)
: Scales updates according to Adam update rules.
scale_by_rms(...)
: Returns a function that scales updates by the RMS of the updates.
scale_by_schedule(...)
: Returns a function that scales updates according to an input schedule.
scale_by_stddev(...)
: Returns a function that scales updates by their standard deviation.
trace(...)
: Returns a function that combines updates with a running state.
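To make the composition model concrete, here is a rough sketch of how `chain`, `clip_by_global_norm`, `scale_by_adam`, and `scale_by_schedule` can fit together. It is written in the explicit `(init, update)` style of Optax, using NumPy. It is not this module's implementation, which relies on Oryx's state system, and the signatures are assumptions: the parameter names (`max_norm`, `b1`, `b2`, `eps`, `schedule`) and the list-of-arrays update format are illustrative only. The Adam defaults are the commonly used values.

```python
import numpy as np

# Hypothetical explicit-state sketches; NOT the Oryx module's code or API.
# Each factory returns (init, update); update maps
# (updates, state) -> (new_updates, new_state). Updates are lists of arrays.

def clip_by_global_norm(max_norm):
    def init(params):
        return ()  # stateless

    def update(updates, state):
        global_norm = np.sqrt(sum(np.sum(u ** 2) for u in updates))
        # Rescale all updates together only when the global norm is too large.
        scale = min(1.0, max_norm / global_norm) if global_norm > 0 else 1.0
        return [u * scale for u in updates], state

    return init, update


def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
    def init(params):
        mu = [np.zeros_like(p) for p in params]
        nu = [np.zeros_like(p) for p in params]
        return (0, mu, nu)  # (step count, first moment, second moment)

    def update(updates, state):
        count, mu, nu = state
        count += 1
        mu = [b1 * m + (1 - b1) * g for m, g in zip(mu, updates)]
        nu = [b2 * v + (1 - b2) * g ** 2 for v, g in zip(nu, updates)]
        # Bias-corrected moment estimates.
        mu_hat = [m / (1 - b1 ** count) for m in mu]
        nu_hat = [v / (1 - b2 ** count) for v in nu]
        new = [m / (np.sqrt(v) + eps) for m, v in zip(mu_hat, nu_hat)]
        return new, (count, mu, nu)

    return init, update


def scale_by_schedule(schedule):
    def init(params):
        return 0  # step counter

    def update(updates, count):
        return [u * schedule(count) for u in updates], count + 1

    return init, update


def chain(*transforms):
    # Apply transformations left to right, each owning one slot of the state.
    def init(params):
        return tuple(t_init(params) for t_init, _ in transforms)

    def update(updates, state):
        new_state = []
        for (_, t_update), s in zip(transforms, state):
            updates, s = t_update(updates, s)
            new_state.append(s)
        return updates, tuple(new_state)

    return init, update


# Usage: clip, rescale with Adam, then apply a constant negative step size.
params = [np.array([3.0, 4.0])]
init, update = chain(
    clip_by_global_norm(1.0),
    scale_by_adam(),
    scale_by_schedule(lambda step: -0.1),
)
state = init(params)
grads = [np.array([3.0, 4.0])]
updates, state = update(grads, state)
params = [p + u for p, u in zip(params, updates)]
```

On the first step, Adam's bias-corrected update is roughly `g / |g|` elementwise, so each parameter moves by about `-0.1`. Keeping each transformation's state in its own slot lets `chain` stay agnostic to what any individual transformation stores.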