A network that takes global and arm observations and outputs rewards.
Inherits From: Network
tf_agents.bandits.networks.global_and_arm_feature_network.GlobalAndArmCommonTowerNetwork(
    observation_spec: tf_agents.typing.types.NestedTensorSpec,
    global_network: tf_agents.typing.types.Network,
    arm_network: tf_agents.typing.types.Network,
    common_network: tf_agents.typing.types.Network,
    name='GlobalAndArmCommonTowerNetwork'
) -> tf_agents.typing.types.Network
This network concatenates the outputs of the global and per-arm networks and passes the result through a common network, which in turn outputs the reward estimates.
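The data flow described above can be illustrated with a small pure-Python sketch for a single, unbatched observation. It is not the TF-Agents implementation: the one-layer towers, the ReLU activations, the weight values, and every function name below are assumptions chosen for illustration. It is only meant to show the common-tower structure: the global tower runs once, the arm tower runs once per arm, and the common tower maps each concatenated pair of outputs to one reward estimate.

```python
# Conceptual sketch of a "common tower" reward network in pure Python.
# Assumptions (not taken from the TF-Agents source): each tower is a single
# linear layer followed by ReLU, and the common tower ends in a linear
# output producing one scalar reward per arm.
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # shape [out_dim][in_dim]


def linear(weights: Matrix, x: Vector) -> Vector:
    """Multiplies an [out_dim][in_dim] weight matrix by an input vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def common_tower_rewards(
    global_features: Vector,
    arm_features: List[Vector],
    global_weights: Matrix,
    arm_weights: Matrix,
    common_hidden: Matrix,
    common_out: Vector,
) -> Vector:
    """Returns one reward estimate per arm for a single observation."""
    # The global tower runs once; its output is shared by every arm.
    global_out = relu(linear(global_weights, global_features))
    rewards = []
    for arm in arm_features:
        # The arm tower runs once per arm.
        arm_out = relu(linear(arm_weights, arm))
        # Concatenate both tower outputs and feed them to the common tower.
        joint = global_out + arm_out
        hidden = relu(linear(common_hidden, joint))
        rewards.append(sum(w * h for w, h in zip(common_out, hidden)))
    return rewards
```

For example, with global features [1.0, 2.0] and identity global weights, two one-dimensional arms [3.0] and [-1.0] with arm weights [[1.0]], a common hidden layer [[1.0, 1.0, 1.0]], and output weights [0.5], the call returns [3.0, 1.5]. The real network operates on batched tensors rather than Python lists.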
Args
observation_spec
The observation spec for the policy that uses this
network.
global_network
The network that takes the global features as input.
arm_network
The network that takes the arm features as input.
common_network
The network that takes as input the concatenation of the
outputs of the global and the arm networks.
name
The name of this instance of GlobalAndArmCommonTowerNetwork.
Attributes
input_tensor_spec
Returns the spec of the input to the network of type InputSpec.
layers
Get the list of all (nested) sub-layers used in this Network.
state_spec
Methods
copy
copy(**kwargs)
Create a shallow copy of this network.
Note: Network layer weights are never copied. This method recreates the Network instance with the same arguments it was initialized with (excepting any new kwargs).
Args
**kwargs
Args to override when recreating this network. Commonly
overridden args include 'name'.
Returns
A shallow copy of this network.
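The behavior described above (rebuild from the original constructor arguments, let new kwargs override them, and never carry weights over) can be sketched with a toy class. HypotheticalNetwork and its _saved_kwargs attribute are invented for this example; the real Network class may track its arguments differently.

```python
from typing import Any, Dict


class HypotheticalNetwork:
    """Toy stand-in for a network that remembers its constructor arguments."""

    def __init__(self, hidden_units: int = 8, name: str = "net"):
        # Record the constructor arguments so the instance can be rebuilt.
        self._saved_kwargs: Dict[str, Any] = {"hidden_units": hidden_units, "name": name}
        self.hidden_units = hidden_units
        self.name = name
        # Stand-in for layer weights; every new instance gets fresh ones.
        self.weights = [0.0] * hidden_units

    def copy(self, **kwargs: Any) -> "HypotheticalNetwork":
        # New kwargs override the saved ones. Weights are NOT carried over,
        # because a brand-new instance is constructed from the arguments.
        return type(self)(**{**self._saved_kwargs, **kwargs})


original = HypotheticalNetwork(hidden_units=4, name="reward_net")
original.weights[0] = 1.0
clone = original.copy(name="target_reward_net")
```

The clone keeps hidden_units=4 and takes the new name, but its weights are freshly initialized rather than copied from original.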
create_variables
create_variables(input_tensor_spec=None, **kwargs)
Forces creation of the network's variables and returns the output specs.
Args
input_tensor_spec
(Optional). Override or provide an input tensor spec
when creating variables.
**kwargs
Other arguments to network.call(), e.g. training=True.
Returns
Output specs: a nested spec calculated from the outputs (excluding any batch dimensions). If any of the output elements is a tfp Distribution, the associated spec entry returned is a DistributionSpec.
Raises
ValueError
If no input_tensor_spec is provided and the network was not given one during construction.
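The lazy-creation contract of create_variables can be sketched with a toy dense layer. HypotheticalLazyDense and its plain shape tuples are assumptions standing in for TF-Agents tensor specs; the sketch only mirrors the documented rules: an explicitly passed spec takes precedence, a missing spec raises ValueError, and the returned output spec has no batch dimension.

```python
from typing import List, Optional, Tuple

Shape = Tuple[int, ...]  # per-example shape, without a batch dimension


class HypotheticalLazyDense:
    """Toy dense layer whose weights are created only once the input shape is known."""

    def __init__(self, units: int, input_shape: Optional[Shape] = None):
        self.units = units
        self.input_shape = input_shape
        self.weights: Optional[List[List[float]]] = None

    def create_variables(self, input_shape: Optional[Shape] = None) -> Shape:
        # An explicitly passed shape takes precedence over the constructor's.
        if input_shape is not None:
            self.input_shape = input_shape
        if self.input_shape is None:
            raise ValueError("No input shape was provided, and none was set at construction.")
        # Weights are created only on the first successful call.
        if self.weights is None:
            in_dim = self.input_shape[-1]
            self.weights = [[0.0] * in_dim for _ in range(self.units)]
        # Output spec without a batch dimension: [..., in_dim] -> [..., units].
        return self.input_shape[:-1] + (self.units,)
```

Calling create_variables() on a layer built without an input shape raises ValueError, while create_variables((5,)) on a 3-unit layer returns (3,).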
get_initial_state
get_initial_state(batch_size=None)
Returns an initial state usable by the network.
Args
batch_size
Tensor or constant: size of the batch dimension. Can be None, in which case no batch dimension is added.
Returns
A nested object of type self.state_spec containing properly initialized Tensors.
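A sketch of the documented batch_size behavior, assuming a zero-filled initial state and using plain shape tuples and dicts as a stand-in for a nested state_spec (both are assumptions, not the library's types). With batch_size=None the per-example shape is used as-is; otherwise a leading batch dimension is prepended. For a network without recurrent state, the state spec would be an empty nest, and so would the initial state.

```python
from typing import Any, Dict, Optional, Tuple, Union

# Hypothetical nested spec: a shape tuple, or a dict of nested specs.
NestedShape = Union[Tuple[int, ...], Dict[str, Any]]


def zeros(shape: Tuple[int, ...]) -> Any:
    """Builds a zero-filled nested list with the given shape (0.0 for a scalar)."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


def initial_state(state_spec: NestedShape, batch_size: Optional[int] = None) -> Any:
    """Maps a nested spec to zero values, optionally adding a batch dimension."""
    if isinstance(state_spec, dict):
        return {key: initial_state(value, batch_size) for key, value in state_spec.items()}
    # batch_size=None means no batch dimension is added.
    shape = tuple(state_spec) if batch_size is None else (batch_size,) + tuple(state_spec)
    return zeros(shape)
```

For example, initial_state({"h": (2,)}) gives {"h": [0.0, 0.0]}, while batch_size=3 gives three such rows.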
get_layer
get_layer(name=None, index=None)
Retrieves a layer based on either its name (unique) or index.
If name and index are both provided, index will take precedence.
Indices are based on order of horizontal graph traversal (bottom-up).
Args
name
String, name of layer.
index
Integer, index of layer.
Returns
A layer instance.
Raises
ValueError
In case of invalid layer name or index.
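The lookup rules above can be sketched in a few lines of Python. HypotheticalLayer and this get_layer function are illustrative stand-ins, not the library code; the layer names reuse this network's tower names only as examples.

```python
from typing import List, Optional


class HypotheticalLayer:
    def __init__(self, name: str):
        self.name = name


def get_layer(layers: List[HypotheticalLayer], name: Optional[str] = None,
              index: Optional[int] = None) -> HypotheticalLayer:
    # When both are given, the index takes precedence over the name.
    if index is not None:
        if not 0 <= index < len(layers):
            raise ValueError(f"Invalid index {index}: there are only {len(layers)} layers.")
        return layers[index]
    if name is not None:
        for layer in layers:
            if layer.name == name:
                return layer
        raise ValueError(f"No such layer: {name}")
    raise ValueError("Provide either a layer name or an index.")


layers = [HypotheticalLayer("global"), HypotheticalLayer("arm"), HypotheticalLayer("common")]
```

Here get_layer(layers, name="arm", index=2) returns the "common" layer, because the index wins.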
summary
summary(line_length=None, positions=None, print_fn=None)
Prints a string summary of the network.
Args
line_length
Total length of printed lines (e.g. set this to adapt the
display to different terminal window sizes).
positions
Relative or absolute positions of log elements in each line. If not provided, defaults to [.33, .55, .67, 1.].
print_fn
Print function to use. Defaults to print. It will be called on each line of the summary. You can set it to a custom function in order to capture the string summary.
Raises
ValueError
If summary() is called before the model is built.
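The interplay of line_length, positions, and print_fn can be sketched with a hypothetical print_row helper that formats one summary row. It assumes that positions whose last entry is at most 1 are relative fractions of line_length and are otherwise absolute column offsets; the default of 98 columns is also an assumption of the sketch. Passing a list's append method as print_fn captures the output instead of printing it.

```python
from typing import Callable, Optional, Sequence


def print_row(fields: Sequence[str], line_length: int = 98,  # assumed default
              positions: Optional[Sequence[float]] = None,
              print_fn: Callable[[str], None] = print) -> None:
    """Formats one summary row into fixed columns and hands it to print_fn."""
    positions = list(positions or [0.33, 0.55, 0.67, 1.0])
    # Assumption: relative positions (last entry <= 1) are scaled to columns.
    if positions[-1] <= 1:
        positions = [int(line_length * p) for p in positions]
    line = ""
    for field, end in zip(fields, positions):
        line += str(field)
        # Truncate at the column boundary, then pad with spaces up to it.
        line = line[:end].ljust(end)
    print_fn(line)


captured = []
print_row(["dense", "(None, 3)"], line_length=20, positions=[0.5, 1.0], print_fn=captured.append)
```

After the call, captured holds the single formatted row instead of it being printed; joining such rows with "\n".join(captured) yields the whole summary as one string.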