tf_agents.bandits.agents.utils.sum_reward_weighted_observations
Calculates an update used by some Bandit algorithms.
tf_agents.bandits.agents.utils.sum_reward_weighted_observations(
    r: tf_agents.typing.types.Tensor,
    x: tf_agents.typing.types.Tensor
) -> tf_agents.typing.types.Tensor
Given an observation x and its corresponding reward r, the weighted observations vector (denoted b here) should be updated as b = b + r * x.
This function calculates that update for a batch of observations x: the sum of the reward-weighted observations r * x over all observations in the batch.
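As an illustration of the update described above, the following NumPy sketch checks that applying b = b + r * x once per observation gives the same result as adding a single batched update, the sum over the batch of r[i] * x[i] (equivalently the matrix-vector product x^T r). This is not the TF-Agents implementation; the helpers `sequential_update` and `batched_update` are hypothetical names used only for this example.

```python
import numpy as np


def sequential_update(b, r, x):
    """Hypothetical helper: apply b = b + r_i * x_i one observation at a time."""
    b = b.copy()
    for r_i, x_i in zip(r, x):
        b = b + r_i * x_i
    return b


def batched_update(r, x):
    """Hypothetical helper: sum of reward-weighted observations over the batch.

    r has shape [batch_size], x has shape [batch_size, context_dim];
    the result has shape [context_dim].
    """
    # Broadcast each reward across its observation row, then sum over the batch axis.
    return np.sum(r[:, np.newaxis] * x, axis=0)


r = np.array([1.0, 0.5, 2.0])                        # rewards, shape [3]
x = np.array([[1.0, 2.0], [4.0, 0.0], [0.5, 1.0]])   # observations, shape [3, 2]
b = np.zeros(2)

print(batched_update(r, x))  # [4. 4.]
print(np.allclose(sequential_update(b, r, x), b + batched_update(r, x)))  # True
```

Summing inside one vectorized reduction avoids a Python-level loop over the batch, which is why a batched helper of this kind is convenient when b is updated from many observations at once.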
Args
  r: A Tensor of shape [batch_size]. The rewards of the batched observations.
  x: A Tensor of shape [batch_size, context_dim]. The matrix of (batched) observations.
Returns
  The update that needs to be added to b. It has the same shape as b, namely [context_dim].
  If the observation matrix x is empty, a zero vector is returned.
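The return-value contract above can be illustrated with a small NumPy sketch, assuming the update is computed as the matrix-vector product x^T r. The helper `reward_weighted_sum` is a hypothetical stand-in, not the library's code. The sketch accumulates b over batches of different sizes and shows that an empty batch of shape [0, context_dim] yields a zero vector with the same shape as b.

```python
import numpy as np


def reward_weighted_sum(r, x):
    """Hypothetical stand-in: returns sum_i r[i] * x[i] with shape [context_dim]."""
    # With batch_size == 0 the contraction runs over an empty axis,
    # so the result is a vector of zeros of length context_dim.
    return x.T @ r


context_dim = 3
b = np.zeros(context_dim)

# Accumulate b over two batches of different sizes.
batches = [
    (np.array([1.0, 2.0]), np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])),
    (np.array([3.0]), np.array([[1.0, 1.0, 0.0]])),
]
for r, x in batches:
    b = b + reward_weighted_sum(r, x)
print(b)  # [4. 5. 3.]

# An empty batch contributes a zero update of the same shape as b.
empty_update = reward_weighted_sum(np.zeros(0), np.zeros((0, context_dim)))
print(empty_update.shape, empty_update)  # (3,) [0. 0. 0.]
```

With TF-Agents installed, the documented call sum_reward_weighted_observations(r, x) would be applied to tensors of the same shapes; the NumPy version here only mirrors the arithmetic described on this page.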
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-04-26 UTC.