Aggregates and scales per-example losses and regularization losses.
tf_agents.utils.common.aggregate_losses(
    per_example_loss=None,
    sample_weight=None,
    global_batch_size=None,
    regularization_loss=None
)
If global_batch_size is given, it is used to scale the loss; otherwise the scale is the size of the batch dimension of per_example_loss multiplied by the number of replicas.
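As a minimal sketch of this scaling rule, written in plain Python rather than TensorFlow: the helper scale_per_example_loss and its num_replicas parameter are hypothetical, and per_example_loss is taken to be one replica's [B] losses as a list of floats. Each replica divides its local sum by the global batch size, so adding the per-replica results gives the mean over the whole global batch.

```python
from typing import List, Optional


def scale_per_example_loss(per_example_loss: List[float],
                           num_replicas: int = 1,
                           global_batch_size: Optional[int] = None) -> float:
    """Hypothetical sketch: sum one replica's losses, divide by the global batch size."""
    if global_batch_size is None:
        # Default described above: local batch dimension times number of replicas.
        global_batch_size = len(per_example_loss) * num_replicas
    return sum(per_example_loss) / global_batch_size


# Two replicas, each holding a local batch of 2 examples (global batch of 4).
replica_losses = [[1.0, 3.0], [2.0, 6.0]]
per_replica = [scale_per_example_loss(l, num_replicas=2) for l in replica_losses]
# Summing across replicas recovers the mean over all 4 examples: (1 + 3 + 2 + 6) / 4.
total = sum(per_replica)  # 3.0
```

Dividing by the local batch size instead would make the cross-replica sum num_replicas times too large, which is presumably why the default multiplies by the number of replicas.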
| Args | Description |
| --- | --- |
| per_example_loss | Per-example loss, shaped [B] or [B, T, ...]. |
| sample_weight | Optional weight for each example: a Tensor shaped [B] or [B, T, ...], or a scalar float. |
| global_batch_size | Optional global batch size. Defaults to (size of the first dimension of per_example_loss) * (number of replicas). |
| regularization_loss | Regularization loss. |
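The sketch below illustrates one plausible way sample_weight combines with a [B, T] per_example_loss. Beyond what this page states, it assumes that the weight is multiplied elementwise into the loss (a scalar or [B] weight being broadcast across T), that a zero weight masks an element even when its loss is inf or NaN, and that the extra T dimension is averaged so each example contributes a single value before the batch scaling. weighted_example_losses is a hypothetical helper, not part of TF-Agents.

```python
from typing import List, Sequence, Union

Number = Union[int, float]


def weighted_example_losses(
        per_example_loss: Sequence[Sequence[float]],
        sample_weight: Union[Number, Sequence[Number], Sequence[Sequence[Number]]] = 1.0,
) -> List[float]:
    """Hypothetical sketch: apply sample_weight to a [B, T] loss, then average over T."""
    batch = []
    for b, row in enumerate(per_example_loss):
        # Resolve per-element weights: scalar, per-example [B], or per-step [B, T].
        if isinstance(sample_weight, (int, float)):
            weights = [sample_weight] * len(row)
        elif isinstance(sample_weight[b], (int, float)):
            weights = [sample_weight[b]] * len(row)
        else:
            weights = list(sample_weight[b])
        # Assumption: a zero weight yields 0 even if the loss is inf or NaN.
        weighted = [0.0 if w == 0 else l * w for l, w in zip(row, weights)]
        # Assumption: average over the extra (time) dimension, one value per example.
        batch.append(sum(weighted) / len(weighted))
    return batch


losses = [[1.0, 3.0], [float("inf"), 4.0]]
weights = [[1.0, 1.0], [0.0, 2.0]]
per_example = weighted_example_losses(losses, weights)  # [2.0, 4.0]
```

Under these assumptions the weights are not renormalized: the resulting [B] values are still divided by the batch size during scaling, not by the sum of the weights.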
| Returns |
| --- |
| An AggregatedLosses named tuple with scalar losses to optimize. |
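To make the return value concrete, here is a hypothetical stand-in, aggregate_losses_sketch, operating on an already-reduced [B] loss for a single replica. The field names total_loss, weighted and regularization, and the choice to divide the regularization loss by the number of replicas (so it is counted once after a cross-replica sum), are assumptions of this sketch rather than facts stated on this page.

```python
from collections import namedtuple
from typing import List, Optional

# Assumed field names for this sketch.
AggregatedLosses = namedtuple("AggregatedLosses",
                              ["total_loss", "weighted", "regularization"])


def aggregate_losses_sketch(per_example_loss: Optional[List[float]] = None,
                            global_batch_size: Optional[int] = None,
                            regularization_loss: Optional[float] = None,
                            num_replicas: int = 1) -> AggregatedLosses:
    """Hypothetical stand-in for one replica's computation on a [B] loss."""
    total = weighted = reg = None
    if per_example_loss is not None:
        if global_batch_size is None:
            global_batch_size = len(per_example_loss) * num_replicas
        weighted = sum(per_example_loss) / global_batch_size
        total = weighted
    if regularization_loss is not None:
        # Assumption: divide by replica count so the term is counted once overall.
        reg = regularization_loss / num_replicas
        total = reg if total is None else total + reg
    return AggregatedLosses(total, weighted, reg)


result = aggregate_losses_sketch([1.0, 2.0, 3.0], regularization_loss=0.5)
# result.weighted == 2.0, result.regularization == 0.5, result.total_loss == 2.5
```

Either input may be omitted; the corresponding field is then None and total_loss comes from whichever term is present.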