Calculates an update used by some Bandit algorithms.
tf_agents.bandits.agents.utils.sum_reward_weighted_observations(
    r: tf_agents.typing.types.Tensor,
    x: tf_agents.typing.types.Tensor
) -> tf_agents.typing.types.Tensor
Given an observation x and its corresponding reward r, the weighted observations vector (denoted b here) should be updated as b = b + r * x.
This function calculates the sum of the reward-weighted observations r * x over a batch of observations x.
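Writing the batch out explicitly makes it clear why one summed update is enough. Assuming the batch holds B observations x_1, ..., x_B (the rows of x, each of length d = context_dim) with rewards r_1, ..., r_B, applying the single-observation rule once per row gives the same b as adding the total below in one step, since the additions commute:

```latex
b \leftarrow b + \sum_{i=1}^{B} r_i \, x_i = b + x^{\top} r,
\qquad r \in \mathbb{R}^{B},\; x \in \mathbb{R}^{B \times d},\; x^{\top} r \in \mathbb{R}^{d}
```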
Args:
  r: a Tensor of shape [batch_size]. These are the rewards of the batched observations.
  x: a Tensor of shape [batch_size, context_dim]. This is the matrix of (batched) observations, one observation per row.
Returns:
  The update that needs to be added to b. It has the same shape as b, namely [context_dim].
  If the observation matrix x is empty, a zero vector is returned.
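As an illustration of the documented behavior, here is a minimal NumPy sketch of the same computation. It assumes NumPy arrays in place of TF tensors, and the helper name sum_reward_weighted_observations_np is hypothetical, not part of TF-Agents. It broadcasts each reward across its observation row and sums over the batch axis, so an empty batch of shape [0, context_dim] naturally yields a zero vector of length context_dim.

```python
import numpy as np


def sum_reward_weighted_observations_np(r: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy sketch of the batched update sum_i r[i] * x[i].

    r: shape [batch_size]; x: shape [batch_size, context_dim].
    Returns an array of shape [context_dim].
    """
    batch_size = x.shape[0]
    # Turn r into a column so each reward scales its whole observation row,
    # then sum the weighted rows over the batch axis.
    return np.sum(np.reshape(r, (batch_size, 1)) * x, axis=0)


# Usage: accumulate the batched update into b, as in b = b + r * x.
context_dim = 3
b = np.zeros(context_dim)
r = np.array([1.0, 2.0])
x = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 1.0]])
b = b + sum_reward_weighted_observations_np(r, x)
print(b)  # [1. 2. 3.]

# An empty batch (shape [0, context_dim]) gives a zero vector.
empty_update = sum_reward_weighted_observations_np(np.zeros(0), np.zeros((0, context_dim)))
print(empty_update)  # [0. 0. 0.]
```

The broadcast-and-sum form is equivalent to the matrix-vector product x.T @ r; it is written this way here only to mirror the per-observation rule b = b + r * x row by row.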