`DqnLossInfo` is stored in the `extras` field of the `LossInfo` instance.
```python
tf_agents.agents.dqn.dqn_agent.DqnLossInfo(
    td_loss, td_error
)
```
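`DqnLossInfo` behaves like a plain namedtuple with two fields. The sketch below uses a stand-in built with `collections.namedtuple` so it runs without TF-Agents installed; the field names and order match the signature above, while the values are made up for illustration.

```python
import collections

# Stand-in with the same field names and order as
# tf_agents.agents.dqn.dqn_agent.DqnLossInfo (assumption: used so this
# snippet runs without TF-Agents installed).
DqnLossInfo = collections.namedtuple("DqnLossInfo", ["td_loss", "td_error"])

# Illustrative values: a scalar loss and one TD error per sample.
info = DqnLossInfo(td_loss=0.25, td_error=[0.5, -0.5])

print(info.td_loss)   # 0.25
print(info.td_error)  # [0.5, -0.5]
```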
Both `td_loss` and `td_error` have a validity mask applied to ensure that no loss or error is calculated for episode boundaries.
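A minimal NumPy sketch of the masking idea, assuming a boolean `is_boundary` array that flags transitions crossing an episode boundary. The array name, the multiply-by-mask approach, and the squared-error loss are illustrative assumptions, not the TF-Agents source. Multiplying by a 0/1 mask zeroes the error at boundary positions, so they contribute nothing to the loss.

```python
import numpy as np

td_targets = np.array([1.0, 2.0, 3.0, 4.0])
q_values = np.array([0.5, 1.5, 1.0, 4.0])
# Hypothetical flag: True where a transition crosses an episode boundary.
is_boundary = np.array([False, True, False, False])

# 1.0 for valid transitions, 0.0 at episode boundaries.
valid_mask = (~is_boundary).astype(np.float64)

# Masked TD error: boundary positions are forced to zero.
td_error = valid_mask * (td_targets - q_values)

# Illustrative squared-error loss; masked positions contribute zero.
td_loss = np.mean(valid_mask * np.square(td_targets - q_values))
```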
`td_loss`: The weighted TD loss (depends on the choice of loss metric and on any weights passed to the DQN loss function).

`td_error`: The unweighted TD errors, which are calculated as:

```python
td_error = td_targets - q_values
```

These can be used to update Prioritized Replay Buffer priorities.
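To make the formula concrete, the sketch below computes the unweighted TD error for a small batch and turns it into replay priorities. The proportional rule `priority = |td_error| + epsilon` follows common prioritized-replay practice; it is an assumption here, as is the hypothetical helper name `priorities_from_td_error`.

```python
import numpy as np


def priorities_from_td_error(td_error, epsilon=1e-6):
    """Hypothetical helper: proportional priorities from TD errors."""
    # Absolute value so positive and negative errors rank equally;
    # epsilon keeps every transition sampleable.
    return np.abs(td_error) + epsilon


td_targets = np.array([1.0, 0.0, 2.5])
q_values = np.array([0.5, 1.0, 2.5])

td_error = td_targets - q_values  # unweighted TD error per sample
priorities = priorities_from_td_error(td_error)
```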
Note that, unlike `td_loss`, `td_error` may contain a time dimension when training in RNN mode. For `td_loss`, this axis is averaged out.
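A NumPy sketch of the shape difference, assuming a batch of 2 sequences with 3 time steps each: `td_error` keeps its `[batch, time]` shape, while the per-sample loss is averaged over the time axis. Squared error stands in for whichever loss metric is configured, and the final mean over the batch is an illustrative assumption.

```python
import numpy as np

# Assumed shapes: [batch_size=2, time=3].
td_targets = np.array([[1.0, 2.0, 3.0],
                       [0.0, 0.0, 0.0]])
q_values = np.array([[0.0, 2.0, 1.0],
                     [1.0, 0.0, 0.0]])

td_error = td_targets - q_values              # shape (2, 3): time axis kept
per_step_loss = np.square(td_error)           # illustrative squared-error metric
per_sample_loss = per_step_loss.mean(axis=1)  # time axis averaged out -> (2,)
td_loss = per_sample_loss.mean()              # scalar over the batch
```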
| Attributes | |
|---|---|
| `td_loss` | A namedtuple alias for field number 0 |
| `td_error` | A namedtuple alias for field number 1 |
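Because each attribute is an alias for a tuple position, positional indexing, attribute access, and unpacking are interchangeable. The stand-in below is a `collections.namedtuple` with the documented field order, used as an assumption so the snippet runs without TF-Agents.

```python
import collections

# Stand-in with the documented field order (assumption, not the TF-Agents import).
DqnLossInfo = collections.namedtuple("DqnLossInfo", ["td_loss", "td_error"])

info = DqnLossInfo(td_loss=1.0, td_error=(0.5, -0.5))

# Each attribute is an alias for a tuple position.
assert info[0] is info.td_loss   # field number 0
assert info[1] is info.td_error  # field number 1

# Unpacking follows the same field order.
td_loss, td_error = info
```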