
Module: oryx.core.state.api

Module for single-dispatch functions for handling state.

This module defines single-dispatch functions that are used to construct and use Modules. The main functions are init, call_and_update and spec.

They are all single-dispatch functions, meaning the implementation that runs depends on the type of their first argument. Implementations can be registered from outside this library, so these functions act as a general API for handling state.
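The dispatch mechanism can be sketched with Python's standard `functools.singledispatch`. This is a toy illustration, not the oryx source: the generic `init` below, the hypothetical `Scale` class, and the dict standing in for a Module are assumptions made only to show how an outside library could register its own type.

```python
import functools


@functools.singledispatch
def init(obj, *args, **kwargs):
    """Default case: no implementation has been registered for type(obj)."""
    raise NotImplementedError(f"init is not registered for {type(obj).__name__}")


class Scale:
    """Hypothetical type defined by some other library."""

    def __init__(self, factor):
        self.factor = factor


@init.register(Scale)
def _init_scale(obj, *args, **kwargs):
    # Toy rule: return an initializer that ignores key and inputs and
    # returns a dict standing in for a Module.
    def initializer(key, *inputs):
        return {"factor": obj.factor}

    return initializer
```

Calling `init(Scale(2.0))(0, 1.0)` dispatches on `Scale` and returns `{"factor": 2.0}`, while `init(3)` falls through to the default and raises `NotImplementedError`.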

Methods

init

init converts an input object into an "initializer" function, i.e. one that takes a PRNGKey and a set of inputs and returns a Module. function.py registers Python functions with this transformation; another potential application is neural network layers.
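The initializer contract (key and inputs in, Module out) can be mimicked for a toy layer. Everything here is hypothetical: `Linear`, `LinearModule`, and `init_linear` are illustrative names, the "key" is a plain integer seed for Python's `random.Random` rather than a JAX PRNGKey, and a frozen dataclass stands in for a Module.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class LinearModule:
    """Toy stand-in for a Module: holds its state explicitly."""
    weights: tuple

    def __call__(self, x):
        # Dot product of the stored weights with the input.
        return sum(w * xi for w, xi in zip(self.weights, x))


@dataclass(frozen=True)
class Linear:
    """Hypothetical layer description: holds hyperparameters, not weights."""
    stddev: float = 1.0


def init_linear(layer):
    """Hypothetical init rule for Linear: returns an initializer (key, x) -> LinearModule."""

    def initializer(key, x):
        rng = random.Random(key)  # toy PRNG seeded by the integer "key"
        # The input's length fixes the weight count, much as x.shape fixes w's shape in f.
        weights = tuple(rng.gauss(0.0, layer.stddev) for _ in range(len(x)))
        return LinearModule(weights)

    return initializer
```

The design point is that the layer description and the initialized state are separate objects: the same key and inputs always reproduce the same module, and the inputs are needed only to decide the shape of the state.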

call_and_update

call_and_update executes the computation associated with an input object, returning the output and a copy of the object with updated state. For a Module, call_and_update(module, ...) runs module.call_and_update, but this behavior can be defined for arbitrary objects; registrations.py, for instance, provides default registrations for Python data structures such as lists and tuples.
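The "output plus updated copy" pattern can be sketched with a toy `call_and_update` built on `functools.singledispatch`. `TickCounter` is hypothetical, and the tuple rule (run elements in sequence, feeding each output into the next) is an assumption chosen for illustration; it is not a description of what registrations.py actually does.

```python
import functools
from dataclasses import dataclass, replace


@functools.singledispatch
def call_and_update(obj, *args, **kwargs):
    raise NotImplementedError(f"call_and_update is not registered for {type(obj).__name__}")


@dataclass(frozen=True)
class TickCounter:
    """Hypothetical stateful object: adds its count to the input, then increments."""
    count: int = 0


@call_and_update.register(TickCounter)
def _(obj, x):
    # Return the output and a *new* TickCounter; obj itself is never mutated.
    return x + obj.count, replace(obj, count=obj.count + 1)


@call_and_update.register(tuple)
def _(objs, x):
    # Assumed rule: chain the elements, threading each output into the next call.
    new_objs = []
    for obj in objs:
        x, new_obj = call_and_update(obj, x)
        new_objs.append(new_obj)
    return x, tuple(new_objs)
```

For example, `call_and_update((TickCounter(1), TickCounter(2)), 0)` returns `3` together with `(TickCounter(2), TickCounter(3))`, and the original tuple is left unchanged.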

We also provide call and update functions, which are wrappers around call_and_update.
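A minimal sketch of such wrappers, assuming (from their names) that `call` keeps only the output and `update` keeps only the updated object. `RunningSum` and its `call_and_update` are hypothetical stand-ins so the block runs on its own.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunningSum:
    """Hypothetical stateful object: outputs the new total and remembers it."""
    total: float = 0.0


def call_and_update(obj, x):
    new_total = obj.total + x
    return new_total, RunningSum(new_total)


def call(obj, *args, **kwargs):
    # Keep only the output; discard the updated state.
    output, _ = call_and_update(obj, *args, **kwargs)
    return output


def update(obj, *args, **kwargs):
    # Keep only the updated object; discard the output.
    _, new_obj = call_and_update(obj, *args, **kwargs)
    return new_obj
```

Chaining the two, `call(update(RunningSum(), 2.0), 3.0)` evaluates to `5.0`.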

spec

spec has the same API as init without the PRNGKey and returns the shape and dtype of the output that would result from calling the input object.
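The idea behind spec, describing an output by shape and dtype without producing data, can be illustrated with a toy shape rule for the 1-D dot product that ends the function f in the module's example. `ToyArraySpec` and `spec_dot` are hypothetical names, dtypes are plain strings, and this is not how oryx computes specs. (JAX itself offers `jax.eval_shape` for abstract shape evaluation.)

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyArraySpec:
    """Hypothetical analogue of ArraySpec: shape and dtype, no data."""
    shape: tuple
    dtype: str


def spec_dot(w_spec, x_spec):
    """Output spec of a 1-D dot product, computed from input specs alone."""
    if len(x_spec.shape) != 1 or w_spec.shape != x_spec.shape:
        raise ValueError("expected two 1-D specs of equal length")
    if w_spec.dtype != x_spec.dtype:
        raise ValueError("dtype mismatch")
    # A dot product of two vectors is a scalar of the same dtype.
    return ToyArraySpec((), x_spec.dtype)
```

With two `(5,)` float32 inputs this yields `ToyArraySpec((), "float32")`, matching the scalar `ArraySpec((), np.float32)` in the module's example.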

Example:

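# Imports assumed for this example.
import jax.numpy as np
from jax import random
from oryx.core.state import api
from oryx.core.state import module

def f(x, init_key=None):
  # Create a state variable 'w' initialized from init_key, then update it.
  w = module.variable(random.normal(init_key, x.shape), name='w')
  w = module.assign(w + 1., name='w')
  return np.dot(w, x)

# Shape and dtype of f's output.
api.spec(f)(random.PRNGKey(0), np.ones(5))  # ==> ArraySpec((), np.float32)

# Build a Module that holds the variable 'w'.
m = api.init(f)(random.PRNGKey(0), np.ones(5))
m.variables()  # ==> {'w': ...}

# Run the Module; new_module carries the updated value of 'w'.
output, new_module = api.call_and_update(m, np.ones(5))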

Classes

class ArraySpec: Encapsulates shape and dtype of an abstract array.

Functions

call(...)

call_and_update(...)

init(...): Transforms an object into a function that initializes a module.

make_array_spec(...)

spec(...): A general purpose transformation for getting the output shape and dtype of an object.

update(...)