UnweightedAggregationFactory for fast Walsh-Hadamard transform.
Inherits From: UnweightedAggregationFactory
tff.aggregators.HadamardTransformFactory(
inner_agg_factory: Optional[tff.aggregators.UnweightedAggregationFactory] = None,
num_repeats: int = 1
)
The created tff.templates.AggregationProcess takes an input structure
and applies the randomized fast Walsh-Hadamard transform to each tensor in the
structure, after reshaping it to a rank-1 tensor. The transform runs in
O(d*log(d)) time, where d is the number of elements of the tensor.
See https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform
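The O(d*log(d)) bound comes from the butterfly structure of the transform: log2(d) passes, each combining every entry with one partner. Below is a minimal pure-Python sketch of the unnormalized transform, for illustration only. The function name fwht is our own and is not part of the TFF API.

```python
def fwht(values):
    """Unnormalized fast Walsh-Hadamard transform of a list whose length is a power of 2.

    Runs log2(d) butterfly passes, each touching all d entries: O(d*log(d)) work.
    """
    x = list(values)  # copy so the caller's list is not modified
    d = len(x)
    if d == 0 or d & (d - 1):
        raise ValueError("length must be a power of 2")
    h = 1
    while h < d:
        # Combine pairs (i, i + h) within each block of size 2*h.
        for start in range(0, d, 2 * h):
            for i in range(start, start + h):
                a, b = x[i], x[i + h]
                x[i], x[i + h] = a + b, a - b
        h *= 2
    return x


print(fwht([1, 0, 0, 0]))  # [1, 1, 1, 1]
print(fwht([1, 1, 1, 1]))  # [4, 0, 0, 0]
```

Applying the unnormalized transform twice returns d times the input. For example, fwht(fwht([3, -1, 2, 5])) gives [12, -4, 8, 20].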
Specifically, for each tensor, the following operations are first performed at
tff.CLIENTS:
1. Flattens the tensor into a rank-1 tensor.
2. Pads the tensor with zeros to d_2 dimensions, where d_2 is the smallest power of 2 larger than or equal to d.
3. Multiplies the padded tensor with random +1/-1 values (i.e. flipping the signs). This is equivalent to multiplication by a diagonal matrix with Rademacher random variables on the diagonal.
4. Applies the fast Walsh-Hadamard transform.
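The four client-side steps can be sketched in plain Python. This is a hypothetical illustration, not TFF code: nested lists stand in for tensors, random.Random(seed) stands in for the source of randomness, and the helper names (flatten, preprocess) are our own. The 1/sqrt(d_2) scaling, which makes the transform orthonormal, is an assumption rather than something stated above.

```python
import math
import random


def fwht(values):
    """Unnormalized fast Walsh-Hadamard transform (length must be a power of 2)."""
    x = list(values)
    h = 1
    while h < len(x):
        for start in range(0, len(x), 2 * h):
            for i in range(start, start + h):
                x[i], x[i + h] = x[i] + x[i + h], x[i] - x[i + h]
        h *= 2
    return x


def flatten(nested):
    """Step 1: flatten an arbitrarily nested list into a rank-1 list."""
    if isinstance(nested, (list, tuple)):
        return [v for item in nested for v in flatten(item)]
    return [nested]


def preprocess(tensor, seed):
    """Hypothetical client-side pipeline mirroring steps 1-4 above."""
    flat = [float(v) for v in flatten(tensor)]              # step 1 (cast to float)
    d = len(flat)
    d_2 = 1 << max(0, (d - 1).bit_length())                 # smallest power of 2 >= d
    padded = flat + [0.0] * (d_2 - d)                       # step 2: zero padding
    rng = random.Random(seed)
    signs = [rng.choice((-1.0, 1.0)) for _ in range(d_2)]   # Rademacher diagonal
    flipped = [s * v for s, v in zip(signs, padded)]        # step 3: sign flip
    # Step 4, scaled by 1/sqrt(d_2) so the transform is orthonormal (an assumption).
    return [v / math.sqrt(d_2) for v in fwht(flipped)]
```

For example, preprocess([[1, 2], [3, 4], [5, 6]], seed=0) returns 8 values (6 elements padded to d_2 = 8) whose sum of squares is still 91, because both the sign flip and the normalized transform preserve the Euclidean norm.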
Steps 3 and 4 are repeated multiple times with independent randomness, if
num_repeats > 1.
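Writing steps 3 and 4 as matrices makes the repetition and its inverse explicit. This sketch assumes the transform is normalized, so that H = H_{d_2}/\sqrt{d_2} is symmetric with H^2 = I, and that D_k is the diagonal sign matrix drawn for repeat k, with r = num_repeats:

$$
y = (H D_r)\cdots(H D_2)(H D_1)\,x,
\qquad
x = (D_1 H)(D_2 H)\cdots(D_r H)\,y.
$$

The inverse follows from H^{-1} = H and D_k^{-1} = D_k: the innermost pair (D_r H)(H D_r) collapses to the identity, and so on outward. Without the normalization, H_{d_2}^2 = d_2 I, and the inverse picks up an extra factor of d_2^{-r}.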
The resulting tensors are passed to the inner_agg_factory. After
aggregation, the inverses of these steps are applied at tff.SERVER in reverse
order.
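The following pure-Python sketch shows repetition and server-side inversion end to end. It makes three assumptions. First, it uses the orthonormal (1/sqrt(d_2)) transform. Second, all clients use the same sign vectors for each repeat, modelled here by a shared seed. That is an illustrative choice, not a description of how TFF distributes randomness. Third, a plain sum replaces the inner aggregator. Every step is linear, so the sum of the rotated client vectors is the rotation of their sum, and the server recovers the aggregate by undoing each repeat in reverse order. All function names are hypothetical.

```python
import math
import random


def fwht_normalized(x):
    """Orthonormal fast Walsh-Hadamard transform; it is its own inverse."""
    x = list(x)
    h = 1
    while h < len(x):
        for start in range(0, len(x), 2 * h):
            for i in range(start, start + h):
                x[i], x[i + h] = x[i] + x[i + h], x[i] - x[i + h]
        h *= 2
    scale = 1.0 / math.sqrt(len(x))
    return [v * scale for v in x]


def sign_vectors(d_2, num_repeats, seed):
    """One independent +1/-1 vector per repeat; clients share them via the seed."""
    rng = random.Random(seed)
    return [[rng.choice((-1.0, 1.0)) for _ in range(d_2)] for _ in range(num_repeats)]


def rotate(x, signs):
    """Client side: for each repeat, flip signs (step 3), then transform (step 4)."""
    for s in signs:
        x = fwht_normalized([a * b for a, b in zip(s, x)])
    return x


def unrotate(y, signs):
    """Server side: undo the repeats in reverse order (transform, then sign flip)."""
    for s in reversed(signs):
        y = [a * b for a, b in zip(s, fwht_normalized(y))]
    return y


# Two hypothetical clients, already flattened and padded to d_2 = 4.
clients = [[1.0, 2.0, 3.0, 0.0], [4.0, -1.0, 0.5, 0.0]]
signs = sign_vectors(d_2=4, num_repeats=2, seed=42)

# A plain sum stands in for the inner aggregator.
aggregated = [sum(col) for col in zip(*(rotate(c, signs) for c in clients))]
recovered = unrotate(aggregated, signs)
expected = [sum(col) for col in zip(*clients)]
print(all(abs(a - b) < 1e-9 for a, b in zip(recovered, expected)))  # True
```

If each client drew its own signs, the summed vectors would no longer share a common rotation, and this inversion would not recover the sum.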
The allowed input dtypes are integers and floats. However, the dtype passed to
the inner_agg_factory will always be a float.
Methods
create
create(
value_type: factory.ValueType
) -> tff.templates.AggregationProcess
Creates a tff.templates.AggregationProcess without weights.
The provided value_type is a non-federated tff.Type, that is, not a
tff.FederatedType.
The returned tff.templates.AggregationProcess will be created for
aggregation of values matching value_type placed at tff.CLIENTS.
That is, its next method will expect type
<S@SERVER, {value_type}@CLIENTS>, where S is the unplaced return type of
its initialize method.
| Args | |
|---|---|
| value_type | A non-federated tff.Type. |

| Returns |
|---|
| A tff.templates.AggregationProcess. |