tff.learning.metrics.FinalizeThenSampleFactory

Aggregation Factory that finalizes and then samples the metrics.

Inherits From: UnweightedAggregationFactory

The created tff.templates.AggregationProcess finalizes each client's metrics locally, and then collects metrics from at most sample_size clients at the tff.SERVER. If more than sample_size clients participate, then sample_size clients are sampled (by the reservoir sampling algorithm); otherwise, all clients' metrics are collected. Sampling is done in a "per-client" manner, i.e., a client, once sampled, contributes all of its metrics to the final result.
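The per-client sampling described above can be sketched in plain Python. This is a minimal illustration of classic reservoir sampling (Algorithm R), not the TFF implementation; the helper name `sample_clients` and the `seed` parameter are hypothetical and exist only for this example.

```python
import random


def sample_clients(client_metrics, sample_size, seed=None):
    """Keeps at most `sample_size` clients using reservoir sampling.

    Hypothetical helper: each element of `client_metrics` is one client's
    complete dictionary of finalized metrics, so a sampled client always
    contributes all of its metrics.
    """
    if sample_size <= 0:
        raise ValueError("sample_size must be positive.")
    rng = random.Random(seed)
    reservoir = []
    for index, metrics in enumerate(client_metrics):
        if index < sample_size:
            # The first `sample_size` clients always enter the reservoir.
            reservoir.append(metrics)
        else:
            # Keep this client with probability sample_size / (index + 1),
            # evicting a uniformly chosen earlier entry.
            slot = rng.randint(0, index)
            if slot < sample_size:
                reservoir[slot] = metrics
    return reservoir


clients = [{"accuracy": 0.1 * i, "loss": 1.0 - 0.1 * i} for i in range(5)]
few = sample_clients(clients[:2], sample_size=3, seed=0)  # fewer clients than sample_size
some = sample_clients(clients, sample_size=3, seed=0)     # more clients than sample_size
```

With two participating clients and a sample size of three, both clients are kept; with five clients, exactly three whole-client records remain.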

The collected metric samples at tff.SERVER have the same structure (i.e., the same keys in a dictionary) as a client's local metrics, except that each leaf node contains a list of scalar metric values, where each value comes from a sampled client, e.g.,

  sampled_metrics_at_server = {
      'metric_a': [a1, a2, ...],
      'metric_b': [b1, b2, ...],
      ...
  }

where "a1" and "b1" are from the same client (similarly for "a2" and "b2", etc.).
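As a rough illustration of that layout, the following plain-Python sketch turns a list of per-client metric dictionaries into one dictionary of per-metric lists. The function `stack_sampled_metrics` and the sample values are hypothetical and are not part of the TFF API.

```python
import collections


def stack_sampled_metrics(sampled_clients):
    """Turns a list of per-client metric dicts into a dict of per-metric lists."""
    stacked = collections.OrderedDict()
    for client in sampled_clients:
        for name, value in client.items():
            # Appending in client order keeps position i aligned across metrics.
            stacked.setdefault(name, []).append(value)
    return stacked


# Made-up finalized metrics from two sampled clients.
sampled_clients = [
    collections.OrderedDict(metric_a=0.5, metric_b=10),
    collections.OrderedDict(metric_a=0.7, metric_b=20),
]
sampled_metrics_at_server = stack_sampled_metrics(sampled_clients)
```

Because every list is filled in the same client order, the values at a given index in `metric_a` and `metric_b` come from the same client.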

Both "current round samples" and "total rounds samples" are returned, and each contains metrics from at most sample_size clients. The "current round samples" are sampled from the clients participating in the current round; the "total rounds samples" are sampled from all clients that have participated so far.
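The difference between the two sampling scopes can be sketched with a random-key variant of reservoir sampling: every client receives a random key, and the clients with the largest keys are kept. This is an assumption-laden illustration, not the TFF implementation; `top_k_by_key`, `run_round`, and the list-based state are hypothetical names and structures chosen for this example.

```python
import heapq
import random


def top_k_by_key(keyed_clients, sample_size):
    """Keeps the `sample_size` (key, client) pairs with the largest random keys."""
    return heapq.nlargest(sample_size, keyed_clients, key=lambda pair: pair[0])


def run_round(total_state, round_clients, sample_size, rng):
    """Returns (new_total_state, current_round_samples, total_rounds_samples)."""
    keyed = [(rng.random(), client) for client in round_clients]
    # Current round samples: only this round's clients compete.
    current = top_k_by_key(keyed, sample_size)
    # Total rounds samples: this round's clients compete with earlier survivors.
    new_total = top_k_by_key(total_state + keyed, sample_size)
    return (new_total,
            [client for _, client in current],
            [client for _, client in new_total])


rng = random.Random(42)
state = []  # carried across rounds: sampled clients plus their random keys
rounds = [
    [{"loss": 1.0}, {"loss": 0.9}],
    [{"loss": 0.8}, {"loss": 0.7}, {"loss": 0.6}],
]
history = []
for round_clients in rounds:
    state, current_samples, total_samples = run_round(state, round_clients, 3, rng)
    history.append((current_samples, total_samples))
```

After the second round, the current round samples contain only second-round clients, while the total rounds samples may mix clients from both rounds; each holds at most three clients.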

The next function of the created tff.templates.AggregationProcess takes the state and local unfinalized metrics reported from tff.CLIENTS, and returns a tff.templates.MeasuredProcessOutput object with the following properties:

  • state: a dictionary of total rounds samples and the sampling metadata (e.g., random values generated by the reservoir sampling algorithm).
  • result: a tuple of current round samples and total rounds samples.
  • measurements: the number of non-finite leaves (NaN or Inf values) in the current round's client values before sampling.
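To make the measurements entry above concrete, here is a small plain-Python sketch that counts non-finite values in a round's client metrics before any sampling. Whether the count is reported per metric or as a single total is not stated here, so the sketch computes both; `count_non_finite` is a hypothetical helper, not a TFF function.

```python
import math


def count_non_finite(client_values):
    """Counts NaN/Inf values per metric name across all clients in a round."""
    counts = {}
    for client in client_values:
        for name, value in client.items():
            counts.setdefault(name, 0)
            if not math.isfinite(value):
                counts[name] += 1
    return counts


# Made-up finalized metrics from three clients in one round.
round_values = [
    {"loss": 0.4, "accuracy": 0.9},
    {"loss": float("nan"), "accuracy": 0.8},
    {"loss": float("inf"), "accuracy": 0.7},
]
non_finite_counts = count_non_finite(round_values)
total_non_finite = sum(non_finite_counts.values())
```

Counting before sampling means a NaN from a client that is later dropped by the sampler is still reflected in the measurement.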

import tensorflow_federated as tff

# Build the sampling aggregation process from the model's metric finalizers
# and the type of its local unfinalized metrics.
sample_process = tff.learning.metrics.FinalizeThenSampleFactory(
    sample_size).create(metric_finalizers, local_unfinalized_metrics_type)
eval_process = tff.learning.algorithms.build_fed_eval(
    model_fn=..., metrics_aggregation_process=sample_process, ...)
state = eval_process.initialize()
for i in range(num_rounds):
  output = eval_process.next(state, client_data_at_round_i)
  state = output.state
  current_round_samples, total_rounds_samples = output.result

The created eval_process can also be used in tff.learning.programs.EvaluationManager.

Args
sample_size An integer specifying the number of clients sampled (by the reservoir sampling algorithm). Metrics from the sampled clients are collected at the server, and this sample_size applies to both the current round samples and the total rounds samples (see the class documentation for details). Default value is 100.

Raises
TypeError If any argument type mismatches.
ValueError If sample_size is not positive.

Methods

create

Creates a tff.templates.AggregationProcess for metrics aggregation.

Args
metric_finalizers Either the result of tff.learning.models.VariableModel.metric_finalizers (an OrderedDict of callables) or the tff.learning.models.FunctionalModel.finalize_metrics method (a callable that takes an OrderedDict argument). If the former, the keys must be the same as those of the OrderedDict returned by tff.learning.models.VariableModel.report_local_unfinalized_metrics. If the latter, the callable must compute over the same keyspace as the result returned by tff.learning.models.FunctionalModel.update_metrics_state.
local_unfinalized_metrics_type A tff.types.StructWithPythonType (with collections.OrderedDict as the Python container) of a client's local unfinalized metrics.

Returns
An instance of tff.templates.AggregationProcess.

Raises
TypeError If any argument type mismatches; if the metric finalizers mismatch the type of local unfinalized metrics; if the initial unfinalized metrics mismatch the type of local unfinalized metrics.
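The two accepted forms of metric_finalizers can be illustrated in plain Python. This is a sketch under assumptions: the metric names, the [sum, count] layout of the unfinalized values, and the helpers `divide`, `finalize_with_dict`, and `finalize_metrics` are all hypothetical, and the key check only loosely mirrors the documented TypeError.

```python
import collections


def divide(pair):
    """Finalizes a [sum, count] pair into a mean."""
    return pair[0] / pair[1]


# Made-up unfinalized metrics from one client, each stored as [sum, count].
local_unfinalized_metrics = collections.OrderedDict(
    accuracy=[8.0, 10.0], loss=[3.0, 10.0])

# Form 1: an OrderedDict of callables with the same keys as the metrics.
metric_finalizers = collections.OrderedDict(accuracy=divide, loss=divide)


def finalize_with_dict(finalizers, unfinalized):
    """Applies each finalizer to the unfinalized value stored under its key."""
    if list(finalizers.keys()) != list(unfinalized.keys()):
        raise TypeError("Finalizer keys must match the unfinalized metric keys.")
    return collections.OrderedDict(
        (name, finalizers[name](unfinalized[name])) for name in unfinalized)


# Form 2: a single callable that takes the whole OrderedDict.
def finalize_metrics(unfinalized):
    return collections.OrderedDict(
        (name, divide(pair)) for name, pair in unfinalized.items())


from_dict = finalize_with_dict(metric_finalizers, local_unfinalized_metrics)
from_callable = finalize_metrics(local_unfinalized_metrics)
```

Both forms produce the same finalized metrics here; the dictionary form makes the per-key correspondence explicit, while the single callable is responsible for covering the same keys itself.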