Aggregation Factory that finalizes and then samples the metrics.
Inherits From: UnweightedAggregationFactory
tff.learning.metrics.FinalizeThenSampleFactory(
sample_size: int = 100
)
The created tff.templates.AggregationProcess finalizes each client's metrics locally, and then collects metrics from at most sample_size clients at tff.SERVER. If more than sample_size clients participate, then sample_size clients are sampled (using the reservoir sampling algorithm); otherwise, all clients' metrics are collected. Sampling is done in a "per-client" manner, i.e., a client, once sampled, contributes all of its metrics to the final result.
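The per-client sampling described above can be sketched in plain Python. This is an illustration under stated assumptions, not the TFF implementation: it uses a random-key variant of reservoir sampling (each client draws one uniform random value, and the sample_size clients with the largest values are kept), which gives a uniform sample without replacement. The helper sample_clients is hypothetical.

```python
import heapq
import random
from typing import Dict, List, Tuple

ClientMetrics = Dict[str, float]
KeyedClient = Tuple[float, ClientMetrics]


def sample_clients(
    client_metrics: List[ClientMetrics],
    sample_size: int,
    rng: random.Random,
) -> List[KeyedClient]:
    """Keeps at most `sample_size` whole clients (hypothetical helper).

    Each client draws one uniform random key, and the clients with the
    `sample_size` largest keys are kept. Because a client is sampled as a
    unit, all of its metrics stay together in the result.
    """
    if sample_size <= 0:
        raise ValueError(f"sample_size must be positive, got {sample_size}.")
    keyed = [(rng.random(), metrics) for metrics in client_metrics]
    # heapq.nlargest returns every item when there are fewer than sample_size.
    return heapq.nlargest(sample_size, keyed, key=lambda pair: pair[0])


# Usage: 10 participating clients, sample_size=3.
rng = random.Random(0)
clients = [{"loss": float(i), "accuracy": 0.5} for i in range(10)]
sampled = sample_clients(clients, sample_size=3, rng=rng)
print(len(sampled))  # 3
```

Keeping the random key next to each sampled client is what later allows samples from different rounds to be merged without re-drawing.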
The collected metric samples at tff.SERVER have the same structure (i.e., the same keys in a dictionary) as the client's local metrics, except that each leaf node contains a list of scalar metric values, where each value comes from a sampled client, e.g.,
sampled_metrics_at_server = {
'metric_a': [a1, a2, ...],
'metric_b': [b1, b2, ...],
...
}
where "a1" and "b1" are from the same client (similarly for "a2" and "b2", etc.).
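The restructuring from per-client dictionaries into one dictionary of per-metric lists can be sketched as follows. It is a minimal illustration, assuming flat dictionaries of scalar metrics in which every client reports the same keys (nested structures are not handled); to_server_structure is a hypothetical name.

```python
from collections import OrderedDict
from typing import Dict, List


def to_server_structure(
    sampled_clients: List[Dict[str, float]],
) -> "OrderedDict[str, List[float]]":
    """Turns per-client metric dicts into one dict of per-metric lists.

    Position i in every list comes from the same sampled client, so
    result["metric_a"][i] and result["metric_b"][i] belong together.
    Assumes all clients report the same keys.
    """
    result: "OrderedDict[str, List[float]]" = OrderedDict()
    for client in sampled_clients:
        for name, value in client.items():
            result.setdefault(name, []).append(value)
    return result


clients = [
    {"metric_a": 1.0, "metric_b": 10.0},  # client 1: a1, b1
    {"metric_a": 2.0, "metric_b": 20.0},  # client 2: a2, b2
]
print(dict(to_server_structure(clients)))
# {'metric_a': [1.0, 2.0], 'metric_b': [10.0, 20.0]}
```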
Both "current round samples" and "total rounds samples" are returned, and each contains metrics from at most sample_size clients. "Current round samples" are drawn from the current round's participating clients, while "total rounds samples" are drawn from all clients that have participated so far.
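One way to maintain both kinds of samples is sketched below, again assuming the random-key variant of reservoir sampling rather than TFF's actual code; run_round is a hypothetical helper. Because each client draws its key only once, the overall top sample_size keys can be updated by merging the previous total with the current round's sample, and the stored keys play the role of the sampling metadata kept in the state.

```python
import heapq
import random
from typing import Dict, List, Tuple

Keyed = Tuple[float, Dict[str, float]]


def run_round(
    round_clients: List[Dict[str, float]],
    total_samples: List[Keyed],
    sample_size: int,
    rng: random.Random,
) -> Tuple[List[Keyed], List[Keyed]]:
    """Returns (current_round_samples, updated_total_samples).

    Hypothetical sketch: each client draws one random key; the current
    round keeps the top `sample_size` keys among this round's clients,
    and the running total keeps the top `sample_size` keys among all
    clients seen so far.
    """
    keyed = [(rng.random(), metrics) for metrics in round_clients]
    current = heapq.nlargest(sample_size, keyed, key=lambda kv: kv[0])
    # A client outside this round's top `sample_size` is beaten by at least
    # `sample_size` other keys, so merging only `current` is sufficient.
    total = heapq.nlargest(
        sample_size, total_samples + current, key=lambda kv: kv[0]
    )
    return current, total


rng = random.Random(42)
total: List[Keyed] = []
for round_num in range(3):
    clients = [{"round": float(round_num), "client": float(c)} for c in range(5)]
    current, total = run_round(clients, total, sample_size=4, rng=rng)
print(len(current), len(total))  # 4 4
```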
The next function of the created tff.templates.AggregationProcess takes the state and the local unfinalized metrics reported from tff.CLIENTS, and returns a tff.templates.MeasuredProcessOutput object with the following properties:
state: a dictionary of total rounds samples and the sampling metadata (e.g., random values generated by the reservoir sampling algorithm).
result: a tuple of current round samples and total rounds samples.
measurements: the number of non-finite (NaN or Inf values) leaves in the current round client values before sampling.
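The measurements count can be illustrated with a small recursive helper. This is a sketch, assuming leaves are Python scalars nested inside dicts, lists, or tuples; count_non_finite is hypothetical and does not reproduce TFF's tensor-level logic.

```python
import math
from typing import Any


def count_non_finite(value: Any) -> int:
    """Counts scalar leaves that are NaN or +/-Inf (hypothetical helper).

    Dicts, lists, and tuples are traversed recursively; every other
    value is treated as a numeric leaf.
    """
    if isinstance(value, dict):
        return sum(count_non_finite(v) for v in value.values())
    if isinstance(value, (list, tuple)):
        return sum(count_non_finite(v) for v in value)
    return 0 if math.isfinite(value) else 1


# All client values of the current round, counted before any sampling.
round_values = [
    {"loss": 0.3, "accuracy": float("nan")},
    {"loss": float("inf"), "accuracy": 0.9},
]
print(count_non_finite(round_values))  # 2
```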
Example usage
The created eval_process can also be used in tff.learning.programs.EvaluationManager.
Raises
TypeError: If any argument type mismatches.
ValueError: If sample_size is not positive.
Methods
create
create(
    metric_finalizers: Union[tff.learning.metrics.MetricFinalizersType, tff.learning.metrics.FunctionalMetricFinalizersType],
    local_unfinalized_metrics_type: tff.types.StructWithPythonType
) -> tff.templates.AggregationProcess
Creates a tff.templates.AggregationProcess for metrics aggregation.
Args
metric_finalizers: Either the result of tff.learning.models.VariableModel.metric_finalizers (an OrderedDict of callables) or the tff.learning.models.FunctionalModel.finalize_metrics method (a callable that takes an OrderedDict argument). If the former, the keys must be the same as those of the OrderedDict returned by tff.learning.models.VariableModel.report_local_unfinalized_metrics. If the latter, the callable must compute over the same keyspace as the result returned by tff.learning.models.FunctionalModel.update_metrics_state.
local_unfinalized_metrics_type: A tff.types.StructWithPythonType (with collections.OrderedDict as the Python container) of a client's local unfinalized metrics.
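The two accepted forms of metric_finalizers can be contrasted with a plain-Python sketch. It is not TFF code: the type aliases and the finalize helper are hypothetical stand-ins, and raising TypeError on a key mismatch only loosely mirrors the documented error for finalizers that do not match the local unfinalized metrics.

```python
from collections import OrderedDict
from typing import Any, Callable, Mapping, Union

# Hypothetical stand-ins for the two accepted forms of metric_finalizers.
PerMetricFinalizers = Mapping[str, Callable[[Any], Any]]
FunctionalFinalizer = Callable[[Mapping[str, Any]], Mapping[str, Any]]


def finalize(
    metric_finalizers: Union[PerMetricFinalizers, FunctionalFinalizer],
    unfinalized: "OrderedDict[str, Any]",
) -> "OrderedDict[str, Any]":
    """Finalizes one client's unfinalized metrics (sketch, not TFF code)."""
    if callable(metric_finalizers):
        # Functional form: one callable over the whole OrderedDict.
        return OrderedDict(metric_finalizers(unfinalized))
    # Mapping form: one finalizer per key; keys must match exactly.
    if set(metric_finalizers) != set(unfinalized):
        raise TypeError("Finalizer keys do not match the unfinalized metrics.")
    return OrderedDict(
        (name, metric_finalizers[name](value))
        for name, value in unfinalized.items()
    )


def finalize_all(metrics: Mapping[str, Any]) -> "OrderedDict[str, Any]":
    correct, total = metrics["accuracy"]
    return OrderedDict(accuracy=correct / total)


unfinalized = OrderedDict(accuracy=[3.0, 4.0])  # [correct, total]
per_metric = OrderedDict(accuracy=lambda v: v[0] / v[1])
print(dict(finalize(per_metric, unfinalized)))    # {'accuracy': 0.75}
print(dict(finalize(finalize_all, unfinalized)))  # {'accuracy': 0.75}
```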
Returns
An instance of tff.templates.AggregationProcess.
Raises
TypeError: If any argument type mismatches, if the metric finalizers mismatch the type of the local unfinalized metrics, or if the initial unfinalized metrics mismatch the type of the local unfinalized metrics.