Creates a common tower network with feedforward towers.
tf_agents.bandits.networks.global_and_arm_feature_network.create_feed_forward_common_tower_network(
    observation_spec: tf_agents.typing.types.NestedTensorSpec,
    global_layers: Sequence[int],
    arm_layers: Sequence[int],
    common_layers: Sequence[int],
    output_dim: int = 1,
    global_preprocessing_combiner: Optional[Callable[..., tf_agents.typing.types.LossFn]] = None,
    arm_preprocessing_combiner: Optional[Callable[..., tf_agents.typing.types.LossFn]] = None,
    activation_fn: Callable[[tf_agents.typing.types.Tensor], tf_agents.typing.types.Tensor] = tf.keras.activations.relu,
    name: Optional[str] = None
) -> tf_agents.typing.types.Network
The network produced by this function can be used either in GreedyRewardPredictionPolicy or in NeuralLinUCBPolicy.
In the former case, the network must have output_dim=1; it will be an instance of QNetwork and is used in the policy as a reward prediction network.
In the latter case, the network is an encoding network whose output is consumed by a reward layer or a LinUCB method, and the specified output_dim is the encoding dimension. A construction sketch for both cases follows.
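As a minimal sketch (not part of the official reference), the snippet below builds both variants over a per-arm observation spec. The feature dimensions, layer sizes, and the use of the create_per_arm_observation_spec helper from tf_agents.bandits.specs are illustrative assumptions; verify them against the bandit spec utilities in your TF-Agents version.

```python
import tensorflow as tf
from tf_agents.bandits.networks import global_and_arm_feature_network
from tf_agents.bandits.specs import utils as bandit_spec_utils

# Example dimensions (assumptions for illustration only).
GLOBAL_DIM = 10   # size of the global (context) feature vector
PER_ARM_DIM = 4   # size of each arm's feature vector
NUM_ACTIONS = 7   # maximum number of arms per decision

# Nested spec with a global part and a per-arm part, built with the helper
# from tf_agents.bandits.specs.
observation_spec = bandit_spec_utils.create_per_arm_observation_spec(
    GLOBAL_DIM, PER_ARM_DIM, NUM_ACTIONS)

# Case 1: reward prediction network (output_dim=1, the default), suitable
# for GreedyRewardPredictionPolicy.
reward_net = (
    global_and_arm_feature_network.create_feed_forward_common_tower_network(
        observation_spec,
        global_layers=(16, 8),
        arm_layers=(16, 8),
        common_layers=(8,)))

# Case 2: encoding network for NeuralLinUCBPolicy; output_dim becomes the
# encoding dimension consumed by the LinUCB head or a reward layer.
encoding_net = (
    global_and_arm_feature_network.create_feed_forward_common_tower_network(
        observation_spec,
        global_layers=(16, 8),
        arm_layers=(16, 8),
        common_layers=(8,),
        output_dim=8))
```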
Args:
  observation_spec: A nested tensor spec containing the specs for global as well as per-arm observations.
  global_layers: Iterable of ints. Specifies the layers of the global tower.
  arm_layers: Iterable of ints. Specifies the layers of the arm tower.
  common_layers: Iterable of ints. Specifies the layers of the common tower.
  output_dim: The output dimension of the network. If 1, the common tower will be a QNetwork. Otherwise, the common tower will be an encoding network with the specified output dimension.
  global_preprocessing_combiner: Preprocessing combiner for global features.
  arm_preprocessing_combiner: Preprocessing combiner for the arm features.
  activation_fn: A keras activation, specifying the activation function used in all layers. Defaults to relu.
  name: The network name to use. Shows up in Tensorboard losses.
Returns:
  A network that takes observations adhering to observation_spec and outputs reward estimates for every action.
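Continuing the sketch above, the snippet below feeds a random batch of observations through the reward network to show what the returned network produces; with output_dim=1 the output holds one reward estimate per action. The dictionary keys 'global' and 'per_arm' follow the spec produced by the helper used earlier and are an assumption worth checking against bandit_spec_utils in your version.

```python
# Evaluate the reward network from the sketch above on a random batch.
BATCH_SIZE = 2

observations = {
    'global': tf.random.uniform([BATCH_SIZE, GLOBAL_DIM]),
    'per_arm': tf.random.uniform([BATCH_SIZE, NUM_ACTIONS, PER_ARM_DIM]),
}

reward_net.create_variables()
# TF-Agents Networks return (output, network_state); here the output carries
# a reward estimate for every action.
reward_estimates, _ = reward_net(observations)
print(reward_estimates.shape)  # expected: (2, 7), one estimate per action
```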