A class used to represent networks used by TF-Agents policies and agents.
tf_agents.networks.Network(
input_tensor_spec=None, state_spec=(), name=None
)
The main differences between a TF-Agents Network and a Keras Layer include: networks keep track of their underlying layers, explicitly represent RNN-like state in inputs and outputs, and simplify variable creation and clone operations.
When calling a network net, one typically passes data through it via:
outputs, next_state = net(observation, network_state=...)
outputs, next_state = net(observation, step_type=..., network_state=...)
outputs, next_state = net(observation) # net.call must fill an empty state
outputs, next_state = net(observation, step_type=...)
outputs, next_state = net(
observation, step_type=..., network_state=..., learning=...)
etc.
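The calling convention above can be illustrated without TensorFlow. The following is a minimal sketch, not the TF-Agents implementation: `ToyNetwork` and its arithmetic are hypothetical stand-ins, and only the shape of the call (an observation in, an `(outputs, next_state)` pair out, with an empty state filled in when none is passed) mirrors the text.

```python
class ToyNetwork:
    """Hypothetical stand-in that mimics the (outputs, next_state) convention."""

    def __init__(self, state_spec=()):
        # An empty tuple means the network carries no recurrent state.
        self.state_spec = state_spec

    def __call__(self, observation, step_type=None, network_state=()):
        if self.state_spec == ():
            # Stateless: transform the observation, pass the empty state through.
            return [2 * x for x in observation], ()
        # Stateful: fill in an empty state, then fold the observation into it.
        running_total = network_state if network_state != () else 0
        running_total += sum(observation)
        return running_total, running_total


stateless = ToyNetwork()
outputs, next_state = stateless([1, 2, 3])

stateful = ToyNetwork(state_spec="scalar")  # Placeholder spec, assumption only.
first_out, state = stateful([1, 2, 3])      # No state passed: starts from 0.
second_out, state = stateful([4], network_state=state)
```

Threading `state` from one call into the next is the part that carries over to real recurrent networks; the stateless case simply returns `()` unchanged.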
To force construction of a network's variables:
net.create_variables()
net.create_variables(input_tensor_spec=...) # To provide an input spec
net.create_variables(training=True) # Provide extra kwargs
net.create_variables(input_tensor_spec, training=True)
To create a copy of the network:
cloned_net = net.copy()
cloned_net.variables # Raises ValueError: cloned net does not share weights.
cloned_net.create_variables(...)
cloned_net.variables # Now new variables have been created.
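One way to picture why the copy starts without variables is to treat `copy()` as re-running the constructor with the original arguments plus any overrides. The sketch below assumes that mechanism; `ToyNetwork`, `_saved_kwargs`, and `num_units` are hypothetical names used for illustration and are not taken from the library.

```python
class ToyNetwork:
    """Hypothetical stand-in: remembers constructor kwargs so copy() can rebuild."""

    def __init__(self, num_units=4, name="toy"):
        self._saved_kwargs = {"num_units": num_units, "name": name}
        self.num_units = num_units
        self.name = name
        self._variables = None  # Created lazily, like a network's weights.

    def create_variables(self):
        self._variables = [0.0] * self.num_units
        return self._variables

    @property
    def variables(self):
        if self._variables is None:
            raise ValueError("Variables have not been created yet.")
        return self._variables

    def copy(self, **kwargs):
        # Re-run the constructor with the saved kwargs, applying any overrides.
        return type(self)(**{**self._saved_kwargs, **kwargs})


net = ToyNetwork(num_units=3, name="original")
net.create_variables()

cloned = net.copy(name="clone")
try:
    cloned.variables
    clone_had_variables = True
except ValueError:
    clone_had_variables = False  # The copy does not share the original's weights.

cloned.create_variables()
```

Because the copy is built from constructor arguments rather than from existing weights, overriding `name` is cheap and the two objects never alias each other's variables.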
Args | |
---|---|
input_tensor_spec | A nest of tf.TypeSpec representing the input observations. Optional. If not provided, create_variables() will fail unless a spec is provided. |
state_spec | A nest of tensor_spec.TensorSpec representing the state needed by the network. Default is (), which means no state. |
name | (Optional.) A string representing the name of the network. |
Attributes | |
---|---|
input_tensor_spec | Returns the spec of the input to the network of type InputSpec. |
layers | Get the list of all (nested) sub-layers used in this Network. |
state_spec | |
Methods
copy
copy(
**kwargs
)
Create a shallow copy of this network.
Args | |
---|---|
**kwargs | Args to override when recreating this network. Commonly overridden args include 'name'. |
Returns | |
---|---|
A shallow copy of this network. |
create_variables
create_variables(
input_tensor_spec=None, **kwargs
)
Force creation of the network's variables.
Return output specs.
Args | |
---|---|
input_tensor_spec | (Optional.) Override or provide an input tensor spec when creating variables. |
**kwargs | Other arguments to network.call(), e.g. training=True. |
Returns | |
---|---|
Output specs: a nested spec calculated from the outputs (excluding any batch dimensions). If any of the output elements is a tfp Distribution, the associated spec entry returned is a DistributionSpec. |
Raises | |
---|---|
ValueError | If no input_tensor_spec is provided, and the network did not provide one during construction. |
get_initial_state
get_initial_state(
batch_size=None
)
Returns an initial state usable by the network.
Args | |
---|---|
batch_size | Tensor or constant: size of the batch dimension. Can be None, in which case no batch dimension is added. |
Returns | |
---|---|
A nested object of type self.state_spec containing properly initialized Tensors. |
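The effect of `batch_size` on the returned state can be sketched in terms of shapes alone. This is an illustration under stated assumptions, not library code: `initial_state_shapes` is a hypothetical helper, the nest is modeled as a plain dict, and the sketch tracks shapes only, not tensor contents.

```python
def initial_state_shapes(state_spec, batch_size=None):
    """Maps a nest (here a dict) of per-example shapes to initial-state shapes."""
    # With batch_size=None no leading dimension is added.
    outer_dims = [] if batch_size is None else [batch_size]
    return {name: outer_dims + list(shape) for name, shape in state_spec.items()}


# Hypothetical two-part recurrent state, each part holding 8 values per example.
lstm_like_spec = {"h": (8,), "c": (8,)}

batched = initial_state_shapes(lstm_like_spec, batch_size=2)
unbatched = initial_state_shapes(lstm_like_spec)
```

The returned structure has the same nesting as the spec, which is what lets it be passed straight back in as `network_state`.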
get_layer
get_layer(
name=None, index=None
)
Retrieves a layer based on either its name (unique) or index.
If name and index are both provided, index will take precedence.
Indices are based on order of horizontal graph traversal (bottom-up).
Args | |
---|---|
name | String, name of layer. |
index | Integer, index of layer. |
Returns | |
---|---|
A layer instance. |
Raises | |
---|---|
ValueError | In case of invalid layer name or index. |
summary
summary(
line_length=None, positions=None, print_fn=None
)
Prints a string summary of the network.
Args | |
---|---|
line_length | Total length of printed lines (e.g. set this to adapt the display to different terminal window sizes). |
positions | Relative or absolute positions of log elements in each line. If not provided, defaults to [.33, .55, .67, 1.]. |
print_fn | Print function to use. Defaults to print. It will be called on each line of the summary. You can set it to a custom function in order to capture the string summary. |
Raises | |
---|---|
ValueError | If summary() is called before the model is built. |
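Capturing the summary as a string relies only on `print_fn` being called once per line. The sketch below shows the pattern with a hypothetical stand-in `summary` function and made-up layer data; with a real network the same `print_fn=captured.append` argument would be passed to `net.summary(...)`.

```python
def summary(layers, print_fn=None):
    """Hypothetical stand-in: emits one line per layer through print_fn."""
    print_fn = print if print_fn is None else print_fn
    for layer_name, num_params in layers:
        print_fn(f"{layer_name}: {num_params} params")


captured = []
# list.append accepts one argument, so it works as a per-line print function.
summary([("dense", 16), ("output", 5)], print_fn=captured.append)
summary_text = "\n".join(captured)
```

Joining the captured lines afterwards gives a single string that can be logged or written to a file instead of going to standard output.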