
tff.learning.add_debug_measurements

Adds measurements suitable for debugging learning processes.

This wraps a tff.aggregators.AggregationFactory in a new factory that produces additional measurements useful for debugging learning processes. The underlying aggregation of client values is unchanged.

These measurements generally concern the norms of the client updates and the norm of the aggregated server update. The implicit weighting is determined by aggregation_factory: if it is weighted, the debugging measurements use the same weighting when computing averages; if it is unweighted, they use uniform weighting.
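To make the weighting rule concrete, here is a minimal plain-Python sketch (not TFF code) of averaging a per-client quantity such as an update norm. The helper `average` is hypothetical, and using per-client example counts as weights is an assumption for illustration; in TFF the weights follow from how aggregation_factory is weighted.

```python
from typing import Optional, Sequence


def average(values: Sequence[float], weights: Optional[Sequence[float]] = None) -> float:
    """Weighted average of per-client values; uniform weighting when weights is None."""
    if weights is None:
        # Unweighted factory: every client counts equally.
        weights = [1.0] * len(values)
    if len(weights) != len(values):
        raise ValueError("values and weights must have the same length")
    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive number")
    return sum(w * v for w, v in zip(weights, values)) / total_weight


# Two hypothetical clients with update norms 1.0 and 3.0.
print(average([1.0, 3.0]))              # 2.0 (uniform weighting)
print(average([1.0, 3.0], [1.0, 3.0]))  # 2.5 (assumed weights, e.g. example counts)
```

With weights, clients that carry more weight pull the average toward their own values; with uniform weighting every client contributes equally.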

The client measurements are:

  • The average Euclidean norm of client updates.
  • The standard deviation of these norms.

The standard deviation we report is the square root of the unbiased variance.

The server measurements are:

  • The maximum entry of the aggregate client update.
  • The Euclidean norm of the aggregate client update.
  • The minimum entry of the aggregate client update.
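A minimal plain-Python sketch of the three server measurements, assuming the aggregate update is given as nested lists of numbers. Following this page's definition, every coordinate of every tensor counts as one entry. The function name and the dictionary keys are hypothetical and are not the measurement names TFF reports.

```python
import math
from typing import Dict, List


def flatten(structure) -> List[float]:
    """Flattens arbitrarily nested lists or tuples of numbers into a flat list of entries."""
    if isinstance(structure, (list, tuple)):
        return [x for item in structure for x in flatten(item)]
    return [float(structure)]


def server_measurements(aggregate) -> Dict[str, float]:
    """Max entry, Euclidean norm, and min entry across all tensors of the aggregate."""
    entries = flatten(aggregate)
    if not entries:
        raise ValueError("the aggregate update has no entries")
    return {
        "max_entry": max(entries),
        "norm": math.sqrt(sum(x * x for x in entries)),
        "min_entry": min(entries),
    }


# A hypothetical aggregate made of two tensors.
print(server_measurements([[0.5, -2.0], [1.0]]))
```

Flattening first means the maximum and minimum are taken across tensor boundaries rather than per tensor.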

In the above, an "entry" means any coordinate across all tensors in the structure. For example, suppose the client structures before aggregation are:

  • Client A: [[-1, -3, -5], [2]]
  • Client B: [[-1, -3, 1], [0]]

If we use unweighted averaging, then the aggregate client update will be the structure [[-1, -3, -2], [1]]. The maximum entry is 1, the minimum entry is -3, and the Euclidean norm is sqrt((-1)^2 + (-3)^2 + (-2)^2 + 1^2) = sqrt(15).
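The worked example can be checked with a few lines of plain Python. This is only an arithmetic check under uniform weighting; the helper `average_structures` is hypothetical and is not how TFF performs the aggregation.

```python
import math

client_a = [[-1, -3, -5], [2]]
client_b = [[-1, -3, 1], [0]]


def average_structures(a, b):
    """Element-wise unweighted mean of two identically shaped nested lists."""
    if isinstance(a, list):
        return [average_structures(x, y) for x, y in zip(a, b)]
    return (a + b) / 2


aggregate = average_structures(client_a, client_b)
# Every coordinate of every tensor is one entry.
entries = [x for tensor in aggregate for x in tensor]

print(aggregate)                               # [[-1.0, -3.0, -2.0], [1.0]]
print(max(entries), min(entries))              # 1.0 -3.0
print(math.sqrt(sum(x * x for x in entries)))  # sqrt(15), about 3.873
```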

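Args:

  • aggregation_factory: A tff.aggregators.AggregationFactory. Can be weighted or unweighted.

Returns:

  A tff.aggregators.AggregationFactory that wraps aggregation_factory and produces the additional debugging measurements.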