Creates a dot product network with feedforward towers.
tf_agents.bandits.networks.global_and_arm_feature_network.create_feed_forward_dot_product_network(
    observation_spec: tf_agents.typing.types.NestedTensorSpec,
    global_layers: Sequence[int],
    arm_layers: Sequence[int],
    activation_fn: Callable[[tf_agents.typing.types.Tensor], tf_agents.typing.types.Tensor] = tf.keras.activations.relu
) -> tf_agents.typing.types.Network
Args
  observation_spec: A nested tensor spec containing the specs for the global as well as the per-arm observations.
  global_layers: Iterable of ints. Specifies the layers of the global tower; each entry is the number of units in one fully connected layer.
  arm_layers: Iterable of ints. Specifies the layers of the arm tower in the same way. The last element of arm_layers must equal the last element of global_layers.
  activation_fn: A Keras activation specifying the activation function used in all layers. Defaults to tf.keras.activations.relu.
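Returns
  A dot product network that takes observations adhering to observation_spec and outputs a reward estimate for every action. Each estimate is the dot product of the global tower's output with the arm tower's output for that action.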
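The returned network can be pictured as two feedforward towers whose outputs meet in a dot product. The pure-Python sketch that follows illustrates that computation under stated assumptions; it is not the TF-Agents implementation. The helper names (make_tower, apply_tower, dot_product_rewards) are hypothetical, the layers are bias-free dense layers with random weights, observations are plain lists rather than a nested tensor structure, and one arm tower with shared weights embeds every arm.

```python
import random
from typing import Callable, List, Sequence

# A dense layer's weights as a nested list of shape [in_dim][out_dim].
Matrix = List[List[float]]


def relu(x: float) -> float:
    """Scalar ReLU, standing in for tf.keras.activations.relu."""
    return max(0.0, x)


def make_tower(input_dim: int, layers: Sequence[int], seed: int) -> List[Matrix]:
    """Hypothetical helper: random weights for a stack of bias-free dense layers.

    Each entry of `layers` is the number of units in one layer, mirroring
    how global_layers and arm_layers are described in the Args section.
    """
    rng = random.Random(seed)
    weights: List[Matrix] = []
    in_dim = input_dim
    for units in layers:
        weights.append([[rng.uniform(-1.0, 1.0) for _ in range(units)]
                        for _ in range(in_dim)])
        in_dim = units
    return weights


def apply_tower(weights: List[Matrix], x: List[float],
                activation: Callable[[float], float]) -> List[float]:
    """Runs x through every layer, applying the activation after each one."""
    for w in weights:
        out_dim = len(w[0])
        x = [activation(sum(x[i] * w[i][j] for i in range(len(x))))
             for j in range(out_dim)]
    return x


def dot_product_rewards(global_obs: List[float], per_arm_obs: List[List[float]],
                        global_tower: List[Matrix], arm_tower: List[Matrix],
                        activation: Callable[[float], float] = relu) -> List[float]:
    """One reward estimate per arm: dot(global embedding, arm embedding)."""
    g = apply_tower(global_tower, global_obs, activation)
    rewards = []
    for arm_features in per_arm_obs:
        # The same arm tower (shared weights) embeds every arm.
        a = apply_tower(arm_tower, arm_features, activation)
        rewards.append(sum(gi * ai for gi, ai in zip(g, a)))
    return rewards


# Usage: both towers end in 3 units, so their embeddings can be dotted.
global_tower = make_tower(input_dim=5, layers=[4, 3], seed=0)
arm_tower = make_tower(input_dim=6, layers=[8, 3], seed=1)
rewards = dot_product_rewards([0.1] * 5, [[0.2] * 6 for _ in range(4)],
                              global_tower, arm_tower)
print(len(rewards))  # 4: one estimate per arm
```

Scoring arms with a dot product means the global embedding is computed once per observation and reused for every arm, while only the arm tower runs once per arm.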
Raises
  ValueError: If the last element of arm_layers does not match the last element of global_layers.
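The last-layer constraint exists because a dot product is only defined between vectors of equal length, so both towers must end with the same number of units. A minimal sketch of the documented check follows, assuming plain Python sequences; check_tower_layers is a hypothetical name and its error message is illustrative, not the library's wording.

```python
from typing import Sequence


def check_tower_layers(global_layers: Sequence[int], arm_layers: Sequence[int]) -> int:
    """Hypothetical check mirroring the documented ValueError condition.

    Assumes both sequences are non-empty. Returns the shared embedding size.
    """
    if arm_layers[-1] != global_layers[-1]:
        # Illustrative message; the library's actual wording may differ.
        raise ValueError(
            f"The last arm layer ({arm_layers[-1]}) must match the last "
            f"global layer ({global_layers[-1]}).")
    return global_layers[-1]


print(check_tower_layers([4, 3], [5, 3]))  # 3
try:
    check_tower_layers([4, 3], [5, 2])
except ValueError as err:
    print("ValueError:", err)
```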