Wraps an AggregationFactory to report additional measurements.
```python
tff.aggregators.add_measurements(
    inner_agg_factory: tff.aggregators.AggregationFactory,
    *,
    client_measurement_fn: Optional[Callable[..., dict[str, Any]]] = None,
    server_measurement_fn: Optional[Callable[..., dict[str, Any]]] = None
) -> tff.aggregators.AggregationFactory
```
The function client_measurement_fn should be a Python callable that is called as client_measurement_fn(value) if inner_agg_factory is unweighted, or as client_measurement_fn(value, weight) if it is weighted. It must be traceable by TFF, accept tff.Value objects placed at CLIENTS as inputs, and return a collections.OrderedDict mapping string names to tensor values placed at SERVER. These values are added to the measurement dict produced by inner_agg_factory.
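To make the weighted/unweighted calling convention concrete, here is a plain-Python model rather than TFF code. It is a sketch under stated assumptions: a list of floats stands in for a CLIENTS-placed value, the returned OrderedDict stands in for the SERVER-placed measurements, and no TFF tracing or placement is modeled. The helper call_client_measurement_fn and both example measurement functions are hypothetical.

```python
import collections
from typing import Any, Callable, Mapping, Optional, Sequence


def call_client_measurement_fn(
    fn: Callable[..., Mapping[str, Any]],
    values: Sequence[float],
    weights: Optional[Sequence[float]] = None,
) -> Mapping[str, Any]:
    """Hypothetical model: weights are passed only for a weighted inner factory."""
    if weights is None:
        return fn(values)  # unweighted inner factory: fn(value)
    return fn(values, weights)  # weighted inner factory: fn(value, weight)


def unweighted_max(values):
    # Example measurement for an unweighted factory: the largest client value.
    return collections.OrderedDict(client_max=max(values))


def weighted_total_weight(values, weights):
    # Example measurement for a weighted factory: the sum of client weights.
    return collections.OrderedDict(total_weight=sum(weights))


# Model of adding the extra measurements to the inner factory's measurement dict.
inner_measurements = collections.OrderedDict(inner_metric=0.0)
extra = call_client_measurement_fn(weighted_total_weight, [1.0, 3.0], [2.0, 6.0])
merged = collections.OrderedDict(**inner_measurements, **extra)
```

Merging with keyword unpacking is a choice of this sketch: it keeps the inner measurements first and raises a TypeError on a duplicate name, which is stricter than TFF may be.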
Similarly, server_measurement_fn should be a Python callable that is called as server_measurement_fn(result), where result is the result of the inner aggregation, placed at SERVER. The name-to-value mapping it returns is likewise added to the measurement dict.
One or both of client_measurement_fn and server_measurement_fn must be
specified.
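The overall behavior can be sketched as a plain-Python model of the wrapper. This is not TFF's implementation: model_add_measurements, weighted_mean, and the lambdas are hypothetical, placements and tracing are ignored, and raising ValueError when neither callable is given is an assumption about how the "one or both" requirement is enforced. The model shows that client_measurement_fn sees the client inputs while server_measurement_fn sees only the aggregated result.

```python
import collections
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple

Measurements = Mapping[str, Any]


def model_add_measurements(
    inner_aggregate: Callable[..., Tuple[float, Measurements]],
    *,
    weighted: bool,
    client_measurement_fn: Optional[Callable[..., Measurements]] = None,
    server_measurement_fn: Optional[Callable[[float], Measurements]] = None,
):
    """Hypothetical plain-Python model of the wrapper; not TFF code."""
    if client_measurement_fn is None and server_measurement_fn is None:
        # Assumed enforcement of "one or both must be specified".
        raise ValueError("Specify client_measurement_fn, server_measurement_fn, or both.")

    def aggregate(values: Sequence[float], weights: Optional[Sequence[float]] = None):
        args = (values, weights) if weighted else (values,)
        result, inner_measurements = inner_aggregate(*args)
        measurements = collections.OrderedDict(inner_measurements)
        if client_measurement_fn is not None:
            measurements.update(client_measurement_fn(*args))  # sees client inputs
        if server_measurement_fn is not None:
            measurements.update(server_measurement_fn(result))  # sees only the result
        return result, measurements

    return aggregate


def weighted_mean(values, weights):
    # Hypothetical inner aggregation that reports one measurement of its own.
    total = sum(weights)
    return sum(v * w for v, w in zip(values, weights)) / total, {"num_clients": len(values)}


agg = model_add_measurements(
    weighted_mean,
    weighted=True,
    client_measurement_fn=lambda v, w: {"max_weight": max(w)},
    server_measurement_fn=lambda r: {"result_abs": abs(r)},
)
result, measurements = agg([1.0, 3.0], [1.0, 3.0])
```

Here the weighted mean is (1·1 + 3·3) / 4 = 2.5, and the measurements contain the inner num_clients entry followed by the client and server additions. Merging by dict update is an assumption of the model; TFF's handling of name collisions is not described on this page.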
| Returns |
|---|
| An AggregationFactory that reports additional measurements. |