It takes a batch of observations as input and outputs the tuple
(logits of shape [batch_size, num_actions, num_atoms], network's state).
The first element of the output is a batch of logits based on the distribution
called C51 from Bellemare et al., 2017 (https://arxiv.org/abs/1707.06887). The
logits are used to compute approximate probability distributions for Q-values
for each potential action, by computing the probabilities at the 51 points
(called atoms) in np.linspace(-10.0, 10.0, 51).
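As a rough illustration of how the logits relate to Q-values, the sketch below applies a softmax over the atom axis and takes the expectation over the support. It uses NumPy only; the random `logits` array is a stand-in for the network's first output, and the names `support`, `softmax`, `q_values`, and `greedy_actions` are illustrative assumptions rather than part of this API.

```python
import numpy as np

num_atoms = 51
# Atom locations, as described above.
support = np.linspace(-10.0, 10.0, num_atoms)

# Stand-in for the network's first output: [batch_size, num_actions, num_atoms].
# Random values are used here only so the example is self-contained.
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, num_atoms))


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Probability of each atom, per batch element and per action.
probabilities = softmax(logits, axis=-1)

# Expected Q-value per action: sum over atoms of probability * atom value.
q_values = (probabilities * support).sum(axis=-1)

# Greedy action per batch element.
greedy_actions = q_values.argmax(axis=-1)
```

Each row of `probabilities` sums to one along the last axis, so every entry of `q_values` lies within the range of the support.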
Args
input_tensor_spec
A tensor_spec.TensorSpec specifying the observation
spec.
action_spec
A tensor_spec.BoundedTensorSpec representing the actions.
num_atoms
The number of atoms to use in our approximate probability
distributions. Defaults to 51 to produce C51.
preprocessing_layers
(Optional.) A nest of tf.keras.layers.Layer
representing preprocessing for the different observations. None of these
layers may already be built. For more details see the documentation
of networks.EncodingNetwork.
preprocessing_combiner
(Optional.) A keras layer that takes a flat list
of tensors and combines them. Good options include tf.keras.layers.Add
and tf.keras.layers.Concatenate(axis=-1). This layer must not already
be built. For more details see the documentation of
networks.EncodingNetwork.
conv_layer_params
Optional list of convolution layer parameters for
observations, where each item is a length-three tuple indicating
(num_units, kernel_size, stride). Here num_units is the number of
convolution filters.
fc_layer_params
Optional list of fully connected parameters for
observations, where each item is the number of units in the layer.
activation_fn
Activation function, e.g. tf.nn.relu or tf.nn.leaky_relu.
name
A string representing the name of the network.
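To make the shape of conv_layer_params and fc_layer_params concrete, here is a small sketch with made-up values. The specific numbers, the 84x84 input size, the 'valid' padding assumption, and the helper `conv_output_size` are all hypothetical and chosen only for illustration; they are not defaults of this network.

```python
# Hypothetical values for illustration; not defaults of the network.
# Each conv item is (num_units, kernel_size, stride).
conv_layer_params = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
# Each fully connected item is the number of units in that layer.
fc_layer_params = [512]


def conv_output_size(input_size, kernel_size, stride):
    """Spatial size after one convolution, assuming 'valid' padding."""
    return (input_size - kernel_size) // stride + 1


# Assumed square 84x84 observation.
size = 84
for num_units, kernel_size, stride in conv_layer_params:
    size = conv_output_size(size, kernel_size, stride)

# Number of features fed to the first fully connected layer after flattening.
flattened = size * size * conv_layer_params[-1][0]
```

Under these assumptions the spatial size shrinks from 84 to 20, then 9, then 7, so the flattened feature vector has 7 * 7 * 64 = 3136 entries.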
Raises
TypeError
action_spec is not a BoundedTensorSpec.
Attributes
input_tensor_spec
Returns the spec of the input to the network of type InputSpec.
layers
Get the list of all (nested) sub-layers used in this Network.
input_tensor_spec
(Optional.) Override or provide an input tensor spec
when creating variables.
**kwargs
Other arguments to network.call(), e.g. training=True.
Returns
Output specs - a nested spec calculated from the outputs (excluding any
batch dimensions). If any of the output elements is a tfp Distribution,
the associated spec entry returned is a DistributionSpec.
Raises
ValueError
If no input_tensor_spec is provided, and the network did
not provide one during construction.
Total length of printed lines (e.g. set this to adapt the
display to different terminal window sizes).
positions
Relative or absolute positions of log elements in each line.
If not provided, defaults to [.33, .55, .67, 1.].
print_fn
Print function to use. Defaults to print. It will be called
on each line of the summary. You can set it to a custom function in
order to capture the string summary.
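The print_fn hook can be used to collect the summary as a string instead of printing it. The sketch below shows the callback pattern only: `fake_summary` is a hypothetical stand-in for the real summary method, and the two lines it emits are invented placeholders, not actual output.

```python
def fake_summary(print_fn=print):
    """Hypothetical stand-in that reports each summary line through print_fn."""
    # Placeholder lines; a real summary would describe the network's layers.
    for line in ("Layer (type)    Output Shape", "dense (Dense)   (None, 51)"):
        print_fn(line)


# Capture the lines by passing a list's append method as print_fn.
lines = []
fake_summary(print_fn=lines.append)
summary_text = "\n".join(lines)
```

Because print_fn is called once per line, any callable that accepts a single string works, for example a logger method.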