Module: oryx.core.state.api

Module for single-dispatch functions for handling state.

This module defines single-dispatch functions that are used to construct and use Modules. The main functions are init, call_and_update and spec.

They are all single-dispatch functions, meaning that each one selects a specific implementation based on the type of its first argument. These implementations can be provided from outside this library, so the functions act as a general API for handling state.
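The dispatch mechanism can be illustrated with Python's standard `functools.singledispatch`, which is assumed here to mirror how these functions choose an implementation. The generic function `describe` and the class `ThirdPartyLayer` are hypothetical names used only for this sketch. The sketch shows that the first argument's type selects the implementation, and that code outside the defining module can register new types later.

```python
import functools


@functools.singledispatch
def describe(obj):
  """Hypothetical generic function; this fallback runs for unregistered types."""
  return 'no implementation registered'


@describe.register(list)
def _describe_list(obj):
  # Selected because type(obj) is list.
  return f'list with {len(obj)} elements'


# Code living outside the defining module can register new types later,
# which is what lets a single-dispatch function act as an extensible API.
class ThirdPartyLayer:
  pass


@describe.register(ThirdPartyLayer)
def _describe_layer(obj):
  return 'third-party layer'


print(describe([1, 2]))             # list with 2 elements
print(describe(ThirdPartyLayer()))  # third-party layer
print(describe(3.0))                # no implementation registered
```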

init converts an input object into an "initializer" function, i.e. one that takes a random PRNGKey and a set of inputs and returns a Module. Python functions are registered with this transformation; another potential application is neural network layers.
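A minimal sketch of the init pattern, using only the standard library: a single-dispatch `init` returns an initializer that takes a seed and some inputs and builds a module with freshly drawn state. `Module`, `Linear`, and the integer seed (standing in for a JAX PRNGKey) are all hypothetical simplifications, not this library's classes.

```python
import functools
import random


class Module:
  """Toy stand-in for a Module: a function of (variables, inputs) plus state."""

  def __init__(self, fn, variables):
    self.fn = fn
    self.variables = variables


@functools.singledispatch
def init(obj):
  raise NotImplementedError(f'no init registered for {type(obj).__name__}')


class Linear:
  """Hypothetical layer description: it holds no weights until initialized."""


@init.register(Linear)
def _init_linear(layer):
  def initializer(seed, x):
    # An integer seed stands in for a PRNGKey; one weight per input element.
    rng = random.Random(seed)
    w = [rng.gauss(0.0, 1.0) for _ in x]

    def apply(variables, inputs):
      # Dot product of the stored weights with the inputs.
      return sum(wi * xi for wi, xi in zip(variables['w'], inputs))

    return Module(apply, {'w': w})
  return initializer


m = init(Linear())(0, [1.0, 2.0, 3.0])
print(len(m.variables['w']))  # 3
```

Because the initializer is a separate function, the same layer description can be initialized many times with different seeds.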

call_and_update executes the computation associated with an input object, returning the output and a copy of the object with updated state. For example, for a Module, call_and_update(module, ...) runs module.call_and_update, but this behavior could be defined for arbitrary objects. For instance, default registrations are provided for various Python data structures such as lists and tuples.
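The sketch below shows the call_and_update contract with the standard library: the result is an (output, updated copy) pair, and the original object is left untouched. `Counter` is a hypothetical toy class, and the list registration, which chains elements like a pipeline, is an assumed behavior chosen for illustration. The library's own list and tuple registrations may behave differently.

```python
import functools


class Counter:
  """Toy stateful object: adds `offset` to its input and counts its calls."""

  def __init__(self, offset, count=0):
    self.offset = offset
    self.count = count


@functools.singledispatch
def call_and_update(obj, *args):
  raise NotImplementedError(f'no call_and_update for {type(obj).__name__}')


@call_and_update.register(Counter)
def _call_counter(obj, x):
  # Return the output and a new object with updated state; obj is unchanged.
  return x + obj.offset, Counter(obj.offset, obj.count + 1)


@call_and_update.register(list)
def _call_list(objs, x):
  # Hypothetical list registration: run the elements as a pipeline, feeding
  # each output into the next element and collecting the updated elements.
  new_objs = []
  for obj in objs:
    x, new_obj = call_and_update(obj, x)
    new_objs.append(new_obj)
  return x, new_objs


layers = [Counter(1), Counter(10)]
out, new_layers = call_and_update(layers, 0)
print(out)                            # 11
print([c.count for c in new_layers])  # [1, 1]
print([c.count for c in layers])      # [0, 0]
```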

We also provide call and update functions, which are wrappers around call_and_update.
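One plausible shape for these wrappers, sketched under the assumption that call keeps only the output and update keeps only the updated object from the pair returned by call_and_update. `RunningSum` is a hypothetical toy class used to exercise them.

```python
import functools


@functools.singledispatch
def call_and_update(obj, *args, **kwargs):
  raise NotImplementedError(f'no call_and_update for {type(obj).__name__}')


class RunningSum:
  """Toy stateful object holding a running total."""

  def __init__(self, total=0.0):
    self.total = total


@call_and_update.register(RunningSum)
def _call_running_sum(obj, x):
  new_total = obj.total + x
  return new_total, RunningSum(new_total)


def call(obj, *args, **kwargs):
  """Run the computation and keep only its output."""
  return call_and_update(obj, *args, **kwargs)[0]


def update(obj, *args, **kwargs):
  """Run the computation and keep only the updated object."""
  return call_and_update(obj, *args, **kwargs)[1]


s = RunningSum()
print(call(s, 2.0))                       # 2.0
print(update(update(s, 2.0), 3.0).total)  # 5.0
```

Because both wrappers delegate to call_and_update, any new type registered there automatically works with call and update too.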


spec has the same API as init without the PRNGKey and returns the shape and dtype of the output that would result from calling the input object.
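To illustrate the idea of spec, the sketch below propagates only shapes and dtypes through a single-dispatch function without creating any state or doing any arithmetic. The `ArraySpec` namedtuple here is a simplified stand-in, not this module's ArraySpec class, and `Dense` is a hypothetical layer that maps inputs of shape (..., d) to (..., units).

```python
import collections
import functools

# Simplified stand-in for ArraySpec: a shape tuple and a dtype name.
ArraySpec = collections.namedtuple('ArraySpec', ['shape', 'dtype'])


class Dense:
  """Hypothetical layer mapping inputs of shape (..., d) to (..., units)."""

  def __init__(self, units):
    self.units = units


@functools.singledispatch
def spec(obj):
  raise NotImplementedError(f'no spec registered for {type(obj).__name__}')


@spec.register(Dense)
def _spec_dense(layer):
  def shape_fn(x_spec):
    # Only shapes and dtypes are propagated; no weights are created and no
    # arithmetic is performed.
    return ArraySpec(x_spec.shape[:-1] + (layer.units,), x_spec.dtype)
  return shape_fn


print(spec(Dense(4))(ArraySpec((2, 3), 'float32')))
# ArraySpec(shape=(2, 4), dtype='float32')
```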


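Example:

```python
# Import paths assumed from the Oryx and JAX packages.
import jax.numpy as np
from jax import random
from oryx.core.state import api, module

def f(x, init_key=None):
  # Create a state variable 'w', initialized from init_key.
  w = module.variable(random.normal(init_key, x.shape), name='w')
  # Record an updated value for 'w'.
  w = module.assign(w + 1., name='w')
  return np.dot(w, x)

api.spec(f)(random.PRNGKey(0), np.ones(5))  # ==> ArraySpec((), np.float32)

m = api.init(f)(random.PRNGKey(0), np.ones(5))
m.variables()  # ==> {'w': ...}

output, new_module = api.call_and_update(m, np.ones(5))
```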


Classes

class ArraySpec: Encapsulates the shape and dtype of an abstract array.

Functions

init(...): Transforms an object into a function that initializes a module.

spec(...): A general-purpose transformation for getting the output shape and dtype of an object.