Base class for neural network layers.
oryx.experimental.nn.Layer(
layer_params, name=None
)
A Layer is a subclass of Module with some additional functionality. Like Modules, Layers have a variables() method that returns a dictionary mapping names to state values. A Layer also has a call_and_update method that returns the output of a computation together with a new Layer whose state has been updated.
Under the hood, Layers do a couple of extra things beyond Modules.
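The functional update pattern described above can be sketched in plain Python. This is not Oryx's implementation: RunningSum is a hypothetical layer that keeps its state in an ordinary float instead of JAX arrays. The point it shows is that call_and_update leaves the original layer untouched and hands back a new one, and that variables() exposes the state as a name-to-value dictionary.

```python
class RunningSum:
    """Hypothetical stateful layer: adds its input to an accumulated total."""

    def __init__(self, total=0.0, name=None):
        self._total = total  # the layer's only state value
        self.name = name

    def variables(self):
        # Dictionary mapping names to state values.
        return {"total": self._total}

    def call_and_update(self, x):
        # Compute the output and return it with a *new* layer; self is unchanged.
        new_total = self._total + x
        return new_total, RunningSum(total=new_total, name=self.name)


layer = RunningSum(name="acc")
out, new_layer = layer.call_and_update(3.0)
print(out)                    # 3.0
print(layer.variables())      # {'total': 0.0}
print(new_layer.variables())  # {'total': 3.0}
```

Returning a fresh object rather than mutating self is what lets such a layer be used inside pure functions, where in-place updates would be invisible to tracing transformations.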
Attributes | |
---|---|
info | Returns the info for this Layer. |
params | Returns the parameters of this Layer. |
state | Returns the state of this Layer. |
Methods
call
call(
*args, **kwargs
)
Calls the Layer's call_and_update and returns the first result, the output of the computation.
call_and_update
call_and_update(
*args, rng=None, **kwargs
)
Uses the layer_cau primitive to call self._call_and_update.
flatten
flatten()
Converts the Layer into a tuple suitable for PyTree flattening.
initialize
@classmethod
@abc.abstractmethod
initialize(
init_key, in_spec
)
Initializes a Layer from an init_key and an input specification (in_spec). This is an abstract method, so each concrete subclass must implement it.
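The abstract-classmethod pattern that initialize follows can be sketched as below: the base class declares the method and each concrete layer supplies it. BaseLayer, Scale, and this LayerParams namedtuple are hypothetical. As an assumption, init_key is a plain integer seed and in_spec a shape tuple, whereas in Oryx these would be a JAX PRNG key and an input specification.

```python
import abc
import collections
import random

# Hypothetical params/state/info container; Oryx defines its own LayerParams.
LayerParams = collections.namedtuple("LayerParams", ["params", "state", "info"])


class BaseLayer(abc.ABC):
    @classmethod
    @abc.abstractmethod
    def initialize(cls, init_key, in_spec):
        """Return a LayerParams built from init_key and in_spec."""


class Scale(BaseLayer):
    @classmethod
    def initialize(cls, init_key, in_spec):
        rng = random.Random(init_key)  # deterministic for a given seed
        # One weight per input feature (last dimension); this layer has no state.
        weights = [rng.uniform(-1.0, 1.0) for _ in range(in_spec[-1])]
        return LayerParams(params=weights, state=(), info={"in_spec": in_spec})


lp = Scale.initialize(0, (4, 3))
print(len(lp.params))  # 3
```

Making initialize a classmethod means parameters can be created before any layer instance exists, and the same seed always reproduces the same weights.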
new
@classmethod
new(
layer_params, name=None
)
Creates a Layer from a LayerParams namedtuple.
Args | |
---|---|
layer_params | LayerParams namedtuple that defines the Layer. |
name | A string name for the Layer. |

Returns | |
---|---|
A Layer object. |
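A classmethod constructor like new can unpack a LayerParams namedtuple into the params, state, and info attributes listed in the Attributes table. In the minimal sketch below, the Identity layer and the LayerParams field order are assumptions made for illustration, not Oryx's definitions.

```python
import collections

# Hypothetical container; Oryx defines its own LayerParams.
LayerParams = collections.namedtuple("LayerParams", ["params", "state", "info"])


class Identity:
    """Hypothetical layer exposing params, state, and info as read-only properties."""

    def __init__(self, layer_params, name=None):
        self._params = layer_params.params
        self._state = layer_params.state
        self._info = layer_params.info
        self.name = name

    @classmethod
    def new(cls, layer_params, name=None):
        # Alternate constructor: build the layer directly from a LayerParams.
        return cls(layer_params, name=name)

    @property
    def params(self):
        return self._params

    @property
    def state(self):
        return self._state

    @property
    def info(self):
        return self._info


layer = Identity.new(LayerParams(params=(), state={"calls": 0}, info="no-op"), name="id")
print(layer.state, layer.name)  # {'calls': 0} id
```

Using cls rather than a hard-coded class name means subclasses inherit a working new without overriding it.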
replace
replace(
params=None, state=None, info=None
)
Returns a copy of the Layer with the given params, state, or info replaced.
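One way to implement replace is sketched below for a hypothetical Frozen class: any field passed as None keeps its current value, and a new object is returned so the original is never mutated. Treating None as "keep the existing value" mirrors the None defaults in the signature above, but it is an assumption about the semantics, and under it a field cannot be replaced with None itself.

```python
class Frozen:
    """Hypothetical immutable-style layer holding params, state, and info."""

    def __init__(self, params, state, info):
        self.params, self.state, self.info = params, state, info

    def replace(self, params=None, state=None, info=None):
        # Fall back to the current value for any field that was not supplied.
        return type(self)(
            params=self.params if params is None else params,
            state=self.state if state is None else state,
            info=self.info if info is None else info,
        )


old = Frozen(params=[1.0], state={"step": 0}, info="v1")
new = old.replace(state={"step": 1})
print(old.state, new.state, new.params is old.params)  # {'step': 0} {'step': 1} True
```

Unreplaced fields are shared by reference rather than copied, which keeps replace cheap when only the state changes.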
unflatten
@classmethod
unflatten(
data, xs
)
Reconstructs the Layer from its flattened form.
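The flatten/unflatten pair can be sketched with a hypothetical ToyLayer in plain Python. It follows the convention of jax.tree_util.register_pytree_node, where the flatten function returns (children, aux_data) and the unflatten function receives (aux_data, children), which matches the unflatten(data, xs) signature above. Which fields Oryx treats as children and which as auxiliary data is an assumption here.

```python
class ToyLayer:
    """Hypothetical layer used only to illustrate PyTree flattening."""

    def __init__(self, params, state, info, name=None):
        self.params, self.state, self.info, self.name = params, state, info, name

    def flatten(self):
        # Children: array-like leaves a transformation may map over.
        # Aux data: static metadata carried through unchanged (an assumption).
        return (self.params, self.state), (self.info, self.name)

    @classmethod
    def unflatten(cls, data, xs):
        # data is the aux data, xs the children, in the order flatten produced.
        params, state = xs
        info, name = data
        return cls(params, state, info, name=name)


layer = ToyLayer(params=[1.0, 2.0], state=[0.0], info={"units": 2}, name="dense")
xs, data = layer.flatten()
restored = ToyLayer.unflatten(data, xs)
print(restored.params, restored.name)  # [1.0, 2.0] dense
```

With JAX installed, such a pair could be registered via jax.tree_util.register_pytree_node(ToyLayer, ToyLayer.flatten, ToyLayer.unflatten), after which functions like jax.tree_util.tree_map would reach the children.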
update
update(
*args, **kwargs
)
Calls the Layer's call_and_update and returns the second result, the new Layer with updated state.
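call and update are thin projections of call_and_update, as the sketch below shows for a hypothetical Counter layer whose state counts invocations. Only the tuple indexing mirrors the descriptions above; the layer itself is illustrative.

```python
class Counter:
    """Hypothetical layer that doubles its input and counts how often it ran."""

    def __init__(self, count=0):
        self.count = count

    def call_and_update(self, x):
        # Returns (output, new layer with updated state).
        return 2 * x, Counter(self.count + 1)

    def call(self, *args, **kwargs):
        return self.call_and_update(*args, **kwargs)[0]  # first result: output

    def update(self, *args, **kwargs):
        return self.call_and_update(*args, **kwargs)[1]  # second result: new layer


layer = Counter()
print(layer.call(5))          # 10
print(layer.update(5).count)  # 1
print(layer.count)            # 0
```

Because both helpers run the full computation, calling call and then update on the same input does the work twice; code that needs both values should call call_and_update once.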
variables
variables()
Returns the variables dictionary for this Layer.
__call__
__call__(
*args, **kwargs
) -> Any
Emulates a regular function call. A Module's dunder call (__call__) ensures that state is updated after the function call: it calls assign on the updated state before returning the function's output.
Args | |
---|---|
*args | The arguments to the module. |
**kwargs | The keyword arguments to the module. |

Returns | |
---|---|
The output of the module. |
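The dunder-call behavior described above can be sketched with a hypothetical Accumulator: __call__ runs the functional call_and_update, copies the new layer's state back onto self with an assign method, and then returns the output, so callers get ordinary mutable-object semantics. The assign signature used here (taking another layer) is an assumption; Oryx's Module.assign may differ.

```python
from typing import Any


class Accumulator:
    """Hypothetical layer that returns its running total and stores it as state."""

    def __init__(self, total=0):
        self.total = total

    def call_and_update(self, x):
        new_total = self.total + x
        return new_total, Accumulator(new_total)

    def assign(self, other):
        # Copy updated state from the returned layer onto this instance.
        self.total = other.total

    def __call__(self, *args, **kwargs) -> Any:
        out, new_layer = self.call_and_update(*args, **kwargs)
        self.assign(new_layer)  # persist the state update in place
        return out


acc = Accumulator()
acc(2)
print(acc(3), acc.total)  # 5 5
```

Keeping call_and_update pure and confining mutation to __call__ lets the same layer serve both styles: functional code threads the returned layer explicitly, while imperative code simply calls the object.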