Aggregation Factory that performs secure summation over metrics.
Inherits From: UnweightedAggregationFactory
tff.learning.metrics.SecureSumFactory(
metric_value_ranges: Optional[UserMetricValueRangeDict] = None
)
The created tff.templates.AggregationProcess uses the inner summation processes created by the tff.aggregators.SecureSumFactory to sum unfinalized metrics from tff.CLIENTS to tff.SERVER.
Internally, metrics are grouped by their value range and dtype, and only one secure aggregation process is created for each group. This optimizes computation tracing and compilation, which can be slow when there are many independent aggregations.
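The grouping idea can be sketched in plain Python. This is only an illustration of the technique, not the library's internal code: the metric names, the string dtype labels, and the helper `group_by_range_and_dtype` are all hypothetical.

```python
import collections

# Hypothetical flat view of metrics: name -> (dtype label, (lower, upper)).
metric_specs = collections.OrderedDict(
    num_examples=("int32", (0, 2**20 - 1)),
    num_batches=("int32", (0, 2**20 - 1)),
    loss_sum=("float32", (0.0, 100.0)),
    accuracy_sum=("float32", (0.0, 100.0)),
    weight=("float32", (0.0, 1.0)),
)


def group_by_range_and_dtype(specs):
    """Groups metric names that share the same dtype and value range."""
    groups = collections.OrderedDict()
    for name, (dtype, value_range) in specs.items():
        groups.setdefault((dtype, value_range), []).append(name)
    return groups


groups = group_by_range_and_dtype(metric_specs)

# Five metrics fall into three groups, so three inner aggregations
# would be enough instead of five.
for key, names in groups.items():
    print(key, names)
```

With one aggregation per group rather than per metric, the number of independent aggregations grows with the number of distinct (dtype, range) pairs, not with the number of metrics.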
The initialize function initializes the state for each inner secure aggregation process. The next function takes the state and the local unfinalized metrics reported from tff.CLIENTS, and returns a tff.templates.MeasuredProcessOutput object with the following properties:
- state: a collections.OrderedDict of the states of the inner secure aggregation processes.
- result: a collections.OrderedDict of securely summed unfinalized metrics.
- measurements: a collections.OrderedDict of the measurements of the inner secure aggregation processes.
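To make the shape of result concrete, the following plain-Python sketch sums unfinalized metrics across clients. It performs an ordinary (not secure) sum and does not call the library. The helper `sum_unfinalized_metrics` and the client values are hypothetical.

```python
import collections


def sum_unfinalized_metrics(client_metrics):
    """Sums each named metric across all clients' OrderedDicts."""
    total = collections.OrderedDict()
    for metrics in client_metrics:
        for name, value in metrics.items():
            total[name] = total.get(name, 0) + value
    return total


# Hypothetical unfinalized metrics reported by three clients.
client_metrics = [
    collections.OrderedDict(num_examples=10, loss_sum=4.5),
    collections.OrderedDict(num_examples=20, loss_sum=2.25),
    collections.OrderedDict(num_examples=5, loss_sum=1.25),
]

result = sum_unfinalized_metrics(client_metrics)
print(result)
```

The summed values are still unfinalized: a quantity such as a mean loss would be derived afterwards from the summed loss_sum and num_examples.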
Args | |
---|---|
metric_value_ranges | An optional collections.OrderedDict that matches the structure of local_unfinalized_metrics_type (a value for each tff.types.TensorType in the type tree). Each leaf in the tree should have a 2-tuple that defines the range of expected values for that variable in the metric. If the entire structure is None, a default range of [0.0, 2.0**20 - 1] will be applied to integer variables and auto-tuned bounds will be applied to float variables. Each leaf may also be None, in which case it gets the default range for its value type, allowing partial user specialization. At runtime, values that fall outside the ranges specified at the leaves will be clipped to within the range. |
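The default-range and clipping rules described for metric_value_ranges can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the library's implementation: `clip_to_ranges` and the metric names are hypothetical, and the auto-tuned float bounds are not modeled, so a float with no explicit range passes through unchanged.

```python
import collections

# Default range for integer variables, as stated in the description above.
DEFAULT_INT_RANGE = (0, 2**20 - 1)


def clip_to_ranges(metrics, ranges=None):
    """Clips each metric value into its range, applying defaults for None."""
    clipped = collections.OrderedDict()
    for name, value in metrics.items():
        value_range = None if ranges is None else ranges.get(name)
        if value_range is None:
            if isinstance(value, int):
                value_range = DEFAULT_INT_RANGE
            else:
                # Auto-tuned float bounds are not modeled in this sketch.
                clipped[name] = value
                continue
        lower, upper = value_range
        clipped[name] = min(max(value, lower), upper)
    return clipped


metrics = collections.OrderedDict(
    num_examples=2**21, loss_sum=250.0, accuracy_sum=0.75
)
# Partial specification: only loss_sum has a user-provided range.
ranges = collections.OrderedDict(
    num_examples=None, loss_sum=(0.0, 100.0), accuracy_sum=None
)

clipped = clip_to_ranges(metrics, ranges)
print(clipped)
```

Here num_examples exceeds the integer default and is clipped to 2**20 - 1, loss_sum is clipped to its user-provided upper bound, and accuracy_sum is left as-is because float defaults are outside the scope of the sketch.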
Raises | |
---|---|
TypeError | If metric_value_ranges is not of the expected type. |
Methods
create
create(
local_unfinalized_metrics_type: tff.types.StructWithPythonType
) -> tff.templates.AggregationProcess
Creates an AggregationProcess for secure summation over metrics.
Args | |
---|---|
local_unfinalized_metrics_type | A tff.types.StructWithPythonType (with collections.OrderedDict as the Python container) of a client's local unfinalized metrics. For example, local_unfinalized_metrics_type could represent the output type of tff.learning.models.VariableModel.report_local_unfinalized_metrics(). |
Returns | |
---|---|
An instance of tff.templates.AggregationProcess. |
Raises | |
---|---|
TypeError | If any argument is not of the expected type. |