Adds measurements suitable for debugging learning processes.
tff.learning.add_debug_measurements(
    aggregation_factory: _AggregationFactory
) -> _AggregationFactory
This wraps a `tff.aggregators.AggregationFactory` as a new factory that produces additional measurements useful for debugging learning processes. The underlying aggregation of client values remains unchanged.
These measurements mainly concern the norms of the client updates and the norm of the aggregated server update. The implicit weighting is determined by `aggregation_factory`: if it is weighted, the debugging measurements use that weighting when computing averages; if it is unweighted, they use uniform weighting.
The client measurements are:
- The average Euclidean norm of client updates.
- The standard deviation of these norms.
The reported standard deviation is the square root of the unbiased variance.
The server measurements are:
- The maximum entry of the aggregate client update.
- The Euclidean norm of the aggregate client update.
- The minimum entry of the aggregate client update.
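The client measurements listed above can be sketched in plain NumPy, without calling TFF. This is a hypothetical illustration, not TFF's implementation: the helper names `global_norm` and `client_norm_stats` are invented here, the norm of a client update is assumed to be taken over all coordinates of all its tensors, and for weighted factories it assumes the standard reliability-weights correction `v1**2 / (v1**2 - v2)` for the unbiased variance (with uniform weights this reduces to the usual `n / (n - 1)` correction).

```python
import numpy as np


def global_norm(structure):
    """Euclidean norm over every coordinate of every tensor in a list structure."""
    flat = np.concatenate([np.ravel(np.asarray(t, dtype=np.float64)) for t in structure])
    return float(np.linalg.norm(flat))


def client_norm_stats(client_updates, client_weights=None):
    """Return (weighted mean, unbiased std dev) of the client update norms.

    Hypothetical sketch; the weighted variance correction is an assumption.
    """
    norms = np.array([global_norm(u) for u in client_updates])
    if client_weights is None:
        weights = np.ones_like(norms)  # unweighted factory: uniform weighting
    else:
        weights = np.asarray(client_weights, dtype=np.float64)
    v1 = weights.sum()       # sum of weights
    v2 = np.sum(weights**2)  # sum of squared weights
    if v1**2 <= v2:
        raise ValueError("need at least two clients with nonzero weight")
    mean = np.sum(weights * norms) / v1
    biased_var = np.sum(weights * (norms - mean) ** 2) / v1
    # Reliability-weights correction; equals n / (n - 1) for uniform weights.
    unbiased_var = biased_var * v1**2 / (v1**2 - v2)
    return float(mean), float(np.sqrt(unbiased_var))


clients = [[[-1, -3, -5], [2]], [[-1, -3, 1], [0]]]
mean, std = client_norm_stats(clients)                  # uniform weighting
w_mean, w_std = client_norm_stats(clients, [3.0, 1.0])  # weighted
```

Scaling all weights by the same constant leaves both statistics unchanged, which is why equal weights reproduce the unweighted result.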
Here, an "entry" means any coordinate of any tensor in the structure. For example, suppose the client structures before aggregation are:
- Client A: `[[-1, -3, -5], [2]]`
- Client B: `[[-1, -3, 1], [0]]`
If we use unweighted averaging, the aggregate client update is the structure `[[-1, -3, -2], [1]]`. The maximum entry is `1`, the minimum entry is `-3`, and the Euclidean norm is `sqrt(15)`, since 1 + 9 + 4 + 1 = 15.
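The worked example can be checked with a short NumPy sketch of the server measurements. It does not use TFF: `mean_structure` and `server_stats` are hypothetical helpers written for illustration, and the aggregation is assumed to be a plain unweighted, tensor-by-tensor mean.

```python
import numpy as np


def mean_structure(client_updates):
    """Unweighted, tensor-by-tensor average of client structures (lists of tensors)."""
    num_tensors = len(client_updates[0])
    return [
        np.mean([np.asarray(u[i], dtype=np.float64) for u in client_updates], axis=0)
        for i in range(num_tensors)
    ]


def server_stats(aggregate):
    """Max entry, min entry, and Euclidean norm over all coordinates of all tensors."""
    flat = np.concatenate([np.ravel(t) for t in aggregate])
    return float(flat.max()), float(flat.min()), float(np.linalg.norm(flat))


clients = [[[-1, -3, -5], [2]], [[-1, -3, 1], [0]]]
aggregate = mean_structure(clients)  # tensors [-1, -3, -2] and [1]
max_entry, min_entry, norm = server_stats(aggregate)
# max_entry == 1.0, min_entry == -3.0, norm == sqrt(15) (about 3.873)
```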
Args | |
---|---|
`aggregation_factory` | A `tff.aggregators.AggregationFactory`. Can be weighted or unweighted. |
Returns | |
---|---|
A `tff.aggregators.AggregationFactory`. | |