Module for single-dispatch functions for handling state.
This module defines single-dispatch functions that are used to construct and use `Module`s. The main functions are `init`, `call_and_update` and `spec`.
They are all single-dispatch functions, meaning they have specific implementations depending on the type of their first input. These implementations can be provided from outside of this library, so they act as a general API for handling state.
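The single-dispatch idea can be illustrated with Python's standard `functools.singledispatch`. This is a minimal sketch, not this library's code; the generic function `describe` and the class `MyLayer` are hypothetical. It shows dispatch on the type of the first argument, plus a registration added from "outside" the code that defines the generic function.

```python
from functools import singledispatch


# "Library" side: a generic function whose fallback rejects unregistered types.
@singledispatch
def describe(obj) -> str:
    raise NotImplementedError(f"no implementation registered for {type(obj).__name__}")


# A default registration shipped alongside the generic function.
@describe.register(list)
def _(obj) -> str:
    return f"list of {len(obj)} items"


# "User" side: a new type whose implementation is registered from outside.
class MyLayer:
    pass


@describe.register(MyLayer)
def _(obj) -> str:
    return "MyLayer instance"
```

The implementation is chosen by `type(obj)` alone, so `describe([1, 2])` and `describe(MyLayer())` route to different functions, while an unregistered type such as `int` falls through to the fallback and raises.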
Methods
init
`init` converts an input object into an "initializer" function, i.e. one that takes in a random PRNGKey and a set of inputs and returns a `Module`. For example, `function.py` registers Python functions with this transformation; another potential application is neural network layers.
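The contract of `init` (object in, initializer out; initializer takes a key and inputs and returns a module) can be sketched in plain Python. Everything here is hypothetical: `ToyModule` stands in for `Module`, an integer seed stands in for a JAX PRNGKey, and the only registration is for plain Python functions. Unlike the example below, which declares `w` inside the function body, this toy initializer simply creates one Gaussian weight per input element under the name `'w'`.

```python
import random
import types
from dataclasses import dataclass, field
from functools import singledispatch
from typing import Any, Callable, Dict, List


@dataclass(frozen=True)
class ToyModule:
    """Hypothetical stand-in for a Module: a pure function plus its variables."""
    fn: Callable[..., Any]
    variables: Dict[str, Any] = field(default_factory=dict)


@singledispatch
def init(obj):
    raise NotImplementedError(f"no init registered for {type(obj).__name__}")


@init.register(types.FunctionType)
def _(f):
    def initializer(seed: int, x: List[float]) -> ToyModule:
        # An integer seed stands in for a JAX PRNGKey (assumption for this sketch).
        rng = random.Random(seed)
        # One weight per input element, mirroring `w` having the shape of `x`.
        w = [rng.gauss(0.0, 1.0) for _ in x]
        return ToyModule(fn=f, variables={"w": w})
    return initializer


def dot(w, x):
    """Pure apply function: the weights are passed in explicitly."""
    return sum(wi * xi for wi, xi in zip(w, x))
```

Calling `init(dot)(0, [1.0, 1.0, 1.0])` returns a `ToyModule` whose `variables` holds a three-element list under `'w'`; reusing the same seed reproduces the same weights.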
call_and_update
`call_and_update` executes the computation associated with an input object, returning the output and a copy of the object with updated state. For example, for a `Module`, `call_and_update(module, ...)` runs `module.call_and_update`, but this behavior could be defined for arbitrary objects. In `registrations.py`, for instance, we provide some default registrations for various Python data structures like lists and tuples.
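Registrations for `call_and_update` can be sketched the same way. `Accumulator` is a hypothetical stateful object, and the `list` rule below (run the elements in sequence, feeding each output into the next) is an assumption chosen for illustration, not necessarily what `registrations.py` does. Each call returns the output together with an updated copy, leaving the original object untouched.

```python
from dataclasses import dataclass, replace
from functools import singledispatch
from typing import Any, Tuple


@dataclass(frozen=True)
class Accumulator:
    """Hypothetical stateful object: adds its input to a running total."""
    total: float = 0.0


@singledispatch
def call_and_update(obj, *args) -> Tuple[Any, Any]:
    raise NotImplementedError(f"no call_and_update registered for {type(obj).__name__}")


@call_and_update.register(Accumulator)
def _(obj, x):
    new_total = obj.total + x
    # Return the output and an updated *copy*; `obj` itself is not mutated.
    return new_total, replace(obj, total=new_total)


@call_and_update.register(list)
def _(obj, x):
    # Hypothetical list rule: thread the value through each element in order.
    new_items = []
    for item in obj:
        x, item = call_and_update(item, x)
        new_items.append(item)
    return x, new_items
```

Because the list rule recurses through `call_and_update`, any element type with its own registration can participate without the list rule knowing about it.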
We also provide `call` and `update` functions, which are wrappers around `call_and_update`.
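One plausible reading of these wrappers is that each keeps one half of the `(output, updated_object)` pair returned by `call_and_update`. The sketch below assumes exactly that; `CallCounter` and its registration are hypothetical.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import Any


@dataclass(frozen=True)
class CallCounter:
    """Hypothetical stateful object: doubles its input and counts calls."""
    count: int = 0


@singledispatch
def call_and_update(obj, *args, **kwargs):
    raise NotImplementedError(f"no call_and_update registered for {type(obj).__name__}")


@call_and_update.register(CallCounter)
def _(obj, x):
    return 2 * x, CallCounter(obj.count + 1)


def call(obj, *args, **kwargs) -> Any:
    """Run the computation and keep only the output."""
    return call_and_update(obj, *args, **kwargs)[0]


def update(obj, *args, **kwargs) -> Any:
    """Run the computation and keep only the updated object."""
    return call_and_update(obj, *args, **kwargs)[1]
```

Since the wrappers only project the pair, any new type gets `call` and `update` for free once it registers `call_and_update`.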
spec
`spec` has the same API as `init` without the PRNGKey and returns the shape and dtype of the output that would result from calling the input object.
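For plain functions, this kind of abstract shape computation can be sketched with `jax.eval_shape`, which traces a function and reports the output's shape and dtype without evaluating it on real data. Whether this library's `spec` uses `jax.eval_shape` internally is an assumption not stated here; `ToyArraySpec` is a hypothetical stand-in for `ArraySpec`, and the sketch requires JAX to be installed.

```python
import types
from dataclasses import dataclass
from functools import singledispatch
from typing import Any, Tuple

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyArraySpec:
    """Hypothetical stand-in for ArraySpec: shape and dtype only."""
    shape: Tuple[int, ...]
    dtype: Any


@singledispatch
def spec(obj):
    raise NotImplementedError(f"no spec registered for {type(obj).__name__}")


@spec.register(types.FunctionType)
def _(f):
    def wrapped(*args):
        # jax.eval_shape traces `f` abstractly instead of running it on data.
        out = jax.eval_shape(f, *args)
        return ToyArraySpec(tuple(out.shape), out.dtype)
    return wrapped
```

For example, `spec(lambda x: jnp.dot(x, x))(jnp.ones(5))` describes a scalar `float32` output (with JAX's default 32-bit mode), matching the kind of result shown for `api.spec` in the example below.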
Example:
```python
from jax import random
import jax.numpy as np
from oryx.core.state import api
from oryx.core.state import module  # assumed location of variable/assign

def f(x, init_key=None):
  w = module.variable(random.normal(init_key, x.shape), name='w')
  w = module.assign(w + 1., name='w')
  return np.dot(w, x)

api.spec(f)(random.PRNGKey(0), np.ones(5))  # ==> ArraySpec((), np.float32)
m = api.init(f)(random.PRNGKey(0), np.ones(5))
m.variables()  # ==> {'w': ...}
output, new_module = api.call_and_update(m, np.ones(5))
```
Classes
`class ArraySpec`: Encapsulates shape and dtype of an abstract array.
Functions
`init(...)`: Transforms an object into a function that initializes a module.
`spec(...)`: A general-purpose transformation for getting the output shape and dtype of an object.