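tf_agents.networks.Network
A class used to represent networks used by TF-Agents policies and agents.
    tf_agents.networks.Network(
        input_tensor_spec=None, state_spec=(), name=None
    )
A TF-Agents Network differs from a Keras Layer in three main ways: it keeps
track of its underlying layers, explicitly represents RNN-like state in its
inputs and outputs, and simplifies variable creation and clone operations.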
When calling a network net, typically one passes data through it via:
    outputs, next_state = net(observation, network_state=...)
    outputs, next_state = net(observation, step_type=..., network_state=...)
    outputs, next_state = net(observation)  # net.call must fill an empty state
    outputs, next_state = net(observation, step_type=...)
    outputs, next_state = net(
        observation, step_type=..., network_state=..., learning=...)
etc.
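The calling convention above can be illustrated without TensorFlow. The following is a minimal sketch, not TF-Agents code: `RunningSumNetwork` is a hypothetical toy class, and using `()` as the "empty state" sentinel is an assumption modeled on the documented `state_spec` default. It only shows the shape of the contract: every call returns an `(outputs, next_state)` pair, and the network fills in a starting state when none is passed.

```python
class RunningSumNetwork:
    """Hypothetical stand-in for a stateful network; not a TF-Agents class."""

    def get_initial_state(self):
        # The state this toy network starts from when the caller passes none.
        return 0.0

    def __call__(self, observation, network_state=()):
        # An empty tuple means "no state supplied", so fill in the initial one.
        if network_state == ():
            network_state = self.get_initial_state()
        next_state = network_state + observation
        outputs = next_state
        return outputs, next_state


net = RunningSumNetwork()
# First call: no state is passed, so the network fills the empty state.
outputs, next_state = net(1.5)
# Later calls: thread the returned state back in explicitly.
outputs2, next_state2 = net(2.0, network_state=next_state)
```

Threading the state through the caller, rather than hiding it inside the object, is what lets the same network instance serve several independent episodes at once.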
To force construction of a network's variables:
    net.create_variables()
    net.create_variables(input_tensor_spec=...)  # To provide an input spec
    net.create_variables(training=True)  # Provide extra kwargs
    net.create_variables(input_tensor_spec, training=True)
To create a copy of the network:
    cloned_net = net.copy()
    cloned_net.variables  # Raises ValueError: cloned net does not share weights.
    cloned_net.create_variables(...)
    cloned_net.variables  # Now new variables have been created.
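The copy behavior shown above (the clone is rebuilt from constructor arguments and shares no weights) can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: `ToyNetwork`, its `units` argument, and the `_saved_kwargs` attribute are hypothetical names chosen for the example.

```python
class ToyNetwork:
    """Hypothetical illustration of copy semantics; not a TF-Agents class."""

    def __init__(self, units, name="ToyNetwork"):
        # Remember the constructor arguments so copy() can replay them.
        self._saved_kwargs = {"units": units, "name": name}
        self.units = units
        self.name = name
        self._variables = None  # Created lazily by create_variables().

    def create_variables(self):
        if self._variables is None:
            self._variables = [0.0] * self.units

    @property
    def variables(self):
        if self._variables is None:
            raise ValueError("Variables have not been created yet.")
        return self._variables

    def copy(self, **kwargs):
        # Re-run the constructor with the saved arguments; new kwargs override
        # the saved ones. No weights are carried over to the new instance.
        return type(self)(**{**self._saved_kwargs, **kwargs})


net = ToyNetwork(units=3)
net.create_variables()

cloned_net = net.copy(name="ToyNetworkClone")
try:
    cloned_net.variables
    clone_had_variables = True
except ValueError:
    clone_had_variables = False  # The clone starts without variables.

cloned_net.create_variables()  # Now the clone owns fresh variables.
```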
Args
input_tensor_spec
A nest of tf.TypeSpec representing the input
observations. Optional. If not provided here, create_variables() will
fail unless a spec is passed to it.
state_spec
A nest of tensor_spec.TensorSpec representing the state
needed by the network. Default is (), which means no state.
name
(Optional.) A string representing the name of the network.
Attributes
input_tensor_spec
Returns the spec of the input to the network of type InputSpec.
layers
Get the list of all (nested) sub-layers used in this Network.
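Methods
create_variables
    create_variables(input_tensor_spec=None, **kwargs)
Force creation of the network's variables and return the output specs.
Args
input_tensor_spec
(Optional.) Override or provide an input tensor spec
when creating variables.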
**kwargs
Other arguments to network.call(), e.g. training=True.
Returns
Output specs - a nested spec calculated from the outputs (excluding any
batch dimensions). If any of the output elements is a tfp Distribution,
the associated spec entry returned is a DistributionSpec.
Raises
ValueError
If no input_tensor_spec is provided, and the network did
not provide one during construction.
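The spec-resolution rule described above can be sketched as follows. This is a simplified, hypothetical model rather than the TF-Agents implementation: `SpecNetwork` is an invented class, specs are represented as plain shape tuples, and the toy network is an identity, so its output spec equals its input spec.

```python
class SpecNetwork:
    """Hypothetical sketch of the spec-resolution rule; not TF-Agents code."""

    def __init__(self, input_tensor_spec=None):
        self.input_tensor_spec = input_tensor_spec
        self.call_kwargs = None

    def create_variables(self, input_tensor_spec=None, **kwargs):
        # Prefer a spec passed here; otherwise fall back to the constructor's.
        if input_tensor_spec is None:
            spec = self.input_tensor_spec
        else:
            spec = input_tensor_spec
        if spec is None:
            raise ValueError(
                "No input_tensor_spec provided, and the network did not "
                "provide one during construction.")
        self.input_tensor_spec = spec
        self.call_kwargs = kwargs  # Extra arguments forwarded to call().
        # Toy identity network: the output spec equals the input spec.
        return spec


with_spec = SpecNetwork(input_tensor_spec=(4,))
output_spec = with_spec.create_variables(training=True)

without_spec = SpecNetwork()
try:
    without_spec.create_variables()
    error_message = None
except ValueError as err:
    error_message = str(err)

# Supplying the spec at call time resolves the error.
late_spec = without_spec.create_variables(input_tensor_spec=(2, 3))
```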
summary
    summary(line_length=None, positions=None, print_fn=None)
Prints a string summary of the network.
Args
line_length
Total length of printed lines (e.g. set this to adapt the
display to different terminal window sizes).
positions
Relative or absolute positions of log elements in each line.
If not provided, defaults to [.33, .55, .67, 1.].
print_fn
Print function to use. Defaults to print. It will be called
on each line of the summary. You can set it to a custom function in
order to capture the string summary.
Raises
ValueError
If summary() is called before the model is built.
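Capturing a summary through print_fn follows a common callback pattern: pass a function that stores each line instead of printing it. The sketch below uses a hypothetical toy `summary` function with invented row data, not the real method, purely to show how a list's `append` can serve as the print function.

```python
def summary(rows, line_length=40, print_fn=print):
    """Hypothetical toy summary printer; not the library implementation."""
    name_width = line_length - 10
    print_fn("=" * line_length)
    for name, param_count in rows:
        # Left-align the name and right-align the count within line_length.
        print_fn(f"{name:<{name_width}}{param_count:>10}")
    print_fn("=" * line_length)


captured = []
# list.append is called once per line, so the list collects the summary.
summary([("dense", 128), ("output", 10)], line_length=30,
        print_fn=captured.append)
summary_text = "\n".join(captured)
```

The same idea should apply to the real method: passing something like `lines.append` as `print_fn` collects the output for logging or testing instead of writing it to the terminal.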
copy
    copy(**kwargs)
Create a shallow copy of this network.
Note: Network layer weights are never copied. This method recreates the
Network instance with the same arguments it was initialized with
(excepting any new kwargs).
Args
**kwargs
Args to override when recreating this network. Commonly overridden
args include 'name'.
Returns
A shallow copy of this network.
get_initial_state
    get_initial_state(batch_size=None)
Returns an initial state usable by the network.
Args
batch_size
Tensor or constant: size of the batch dimension. Can be None, in which
case no dimensions get added.
Returns
A nested object of type self.state_spec containing properly initialized
Tensors.
get_layer
    get_layer(name=None, index=None)
Retrieves a layer based on either its name (unique) or index.
If name and index are both provided, index will take precedence.
Indices are based on order of horizontal graph traversal (bottom-up).
Args
name
String, name of layer.
index
Integer, index of layer.
Returns
A layer instance.
Raises
ValueError
In case of invalid layer name or index.
Last updated 2024-04-26 UTC.