UnweightedAggregationFactory for discrete Fourier transform.
Inherits From: UnweightedAggregationFactory
tff.aggregators.DiscreteFourierTransformFactory(
    inner_agg_factory: Optional[tff.aggregators.UnweightedAggregationFactory] = None,
    num_repeats: int = 1
)
The created tff.templates.AggregationProcess takes an input structure and applies the randomized discrete Fourier transform (using TF's fast Fourier transform implementation tf.signal.fft/ifft) to each tensor in the structure, reshaped to a rank-1 tensor. The transform of each tensor takes O(d*log(d)) time, where d is the number of elements of the tensor.
See https://en.wikipedia.org/wiki/Discrete_Fourier_transform
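To make the O(d*log(d)) claim concrete, the following pure-Python sketch (not TFF or TensorFlow code) compares a naive O(N^2) discrete Fourier transform with a recursive radix-2 Cooley-Tukey FFT, which halves even-length inputs at each level. The names naive_dft and fft are hypothetical; tf.signal.fft is the optimized routine the factory actually relies on.

```python
import cmath
import math


def naive_dft(x):
    """Direct O(N^2) DFT: X[k] = sum_n x[n] * exp(-2*pi*i*k*n/N)."""
    n = len(x)
    return [
        sum(x[m] * cmath.exp(-2j * math.pi * k * m / n) for m in range(n))
        for k in range(n)
    ]


def fft(x):
    """Recursive radix-2 Cooley-Tukey FFT.

    Splits even-length inputs into even- and odd-indexed halves, which gives
    O(N log N) work when N is a power of two. Odd lengths fall back to the
    naive DFT, so any length is still handled correctly.
    """
    n = len(x)
    if n <= 1 or n % 2:
        return naive_dft(x)
    even = fft(x[0::2])
    odd = fft(x[1::2])
    out = [0j] * n
    for k in range(n // 2):
        # Combine the two half-size transforms with the twiddle factor.
        twiddle = cmath.exp(-2j * math.pi * k / n) * odd[k]
        out[k] = even[k] + twiddle
        out[k + n // 2] = even[k] - twiddle
    return out


signal = [1.0, 2.0, 0.0, -1.0, 3.0, 0.5, 2.5, -2.0]
fast = fft(signal)
slow = naive_dft(signal)
print(max(abs(a - b) for a, b in zip(fast, slow)))  # a value near 0 (round-off only)
```

Both functions compute the same transform; only the amount of work differs, which is why the factory can afford to apply it to every tensor in a model update.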
Specifically, for each tensor, the following operations are first performed at tff.CLIENTS:
1. Flattens the tensor into a rank-1 tensor.
2. Pads the tensor with zeros to an even number of elements (i.e. pads at most one zero).
3. Packs the real-valued tensor into a tensor with a complex dtype with d/2 elements, by filling the real and imaginary parts with the two halves of the tensor.
4. Randomly rotates each coordinate of the complex tensor.
5. Applies the discrete Fourier transform.
6. Unpacks the complex tensor back to a real tensor of length d.
7. Normalizes the tensor by 1 / sqrt(d/2).
Steps 4 and 5 are repeated multiple times with independent randomness if num_repeats > 1.
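The client-side steps can be illustrated with the pure-Python sketch below, for num_repeats=1 only. It is not the TFF implementation: it uses Python's random.Random in place of TensorFlow's random ops, draws rotation angles uniformly from [0, 2*pi), places the first half of the tensor in the real parts and the second half in the imaginary parts, and uses a naive DFT. These choices, and the helper names flatten, dft, and encode, are assumptions for illustration. The final 1/sqrt(d/2) scaling matters because an unnormalized DFT of length N multiplies the squared norm by N (Parseval's theorem), so normalizing makes the whole encoding norm-preserving.

```python
import cmath
import math
import random


def flatten(t):
    """Step 1: flatten a nested list (a stand-in for a tensor) into floats."""
    if isinstance(t, (list, tuple)):
        return [v for item in t for v in flatten(item)]
    return [float(t)]


def dft(z):
    """Naive O(N^2) DFT; TFF uses tf.signal.fft instead."""
    n = len(z)
    return [
        sum(z[m] * cmath.exp(-2j * math.pi * k * m / n) for m in range(n))
        for k in range(n)
    ]


def encode(tensor, seed):
    """Sketch of steps 1-7 for a single tensor with num_repeats=1."""
    x = flatten(tensor)                          # 1. flatten
    if len(x) % 2:
        x.append(0.0)                            # 2. pad to an even length d
    half = len(x) // 2                           # d/2 complex elements
    # 3. pack: first half -> real parts, second half -> imaginary parts (assumed).
    z = [complex(x[k], x[k + half]) for k in range(half)]
    rng = random.Random(seed)                    # assumption: seed shared by clients
    angles = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(half)]
    z = [v * cmath.exp(1j * a) for v, a in zip(z, angles)]  # 4. rotate
    spectrum = dft(z)                            # 5. DFT
    y = [v.real for v in spectrum] + [v.imag for v in spectrum]  # 6. unpack to length d
    scale = 1.0 / math.sqrt(half)                # 7. normalize by 1/sqrt(d/2)
    return [v * scale for v in y]


client_tensor = [[1, 2, 3], [4, 5, 6], [7]]  # 7 ints, padded to d = 8
encoded = encode(client_tensor, seed=42)
print(len(encoded))                 # 8
print(sum(v * v for v in encoded))  # approximately 140.0, the input's squared norm
```

Integer inputs come out as floats, which mirrors the note that the inner_agg_factory always receives a float dtype.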
The resulting tensors are passed to the inner_agg_factory. After aggregation, at tff.SERVER, the inverses of these steps are applied in reverse order.
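Because every step is linear once the rotation angles are fixed, summing the clients' encoded vectors and then applying the inverse steps at the server recovers the sum of the original tensors. The pure-Python sketch below checks this, with a plain element-wise sum standing in for the inner_agg_factory. It assumes that all clients and the server derive the same rotation angles from one shared seed, uses num_repeats=1, takes already-flattened inputs, and packs the first half of each vector into the real parts. The helper names (dft, rotation_angles, encode, decode) are hypothetical and do not correspond to TFF's internal code.

```python
import cmath
import math
import random


def dft(z, sign=-1):
    """Naive DFT (sign=-1) or unnormalized inverse DFT (sign=+1)."""
    n = len(z)
    return [
        sum(z[m] * cmath.exp(sign * 2j * math.pi * k * m / n) for m in range(n))
        for k in range(n)
    ]


def rotation_angles(seed, half):
    # Assumption: a single seed is shared by all clients and the server.
    rng = random.Random(seed)
    return [rng.uniform(0.0, 2.0 * math.pi) for _ in range(half)]


def encode(values, seed):
    """Client side (steps 2-7) for an already-flat list of numbers."""
    x = [float(v) for v in values]
    if len(x) % 2:
        x.append(0.0)
    half = len(x) // 2
    angles = rotation_angles(seed, half)
    z = [complex(x[k], x[k + half]) * cmath.exp(1j * angles[k]) for k in range(half)]
    spectrum = dft(z)
    scale = 1.0 / math.sqrt(half)
    return [v.real * scale for v in spectrum] + [v.imag * scale for v in spectrum]


def decode(y, seed, num_elements):
    """Server side: inverses of steps 7 down to 2, in reverse order."""
    half = len(y) // 2
    y = [v * math.sqrt(half) for v in y]                          # undo 7
    spectrum = [complex(y[k], y[k + half]) for k in range(half)]  # undo 6
    z = [v / half for v in dft(spectrum, sign=+1)]                # undo 5 (inverse DFT)
    angles = rotation_angles(seed, half)
    z = [v * cmath.exp(-1j * a) for v, a in zip(z, angles)]       # undo 4
    x = [v.real for v in z] + [v.imag for v in z]                 # undo 3
    return x[:num_elements]  # undo 2 (drop padding); reshaping is left to the caller


clients = [[1, 2, 3, 4, 5], [10, -2, 0, 7, 1], [0.5, 0.5, 0.5, 0.5, 0.5]]
seed = 7
encoded = [encode(c, seed) for c in clients]
# Stand-in for the inner aggregation: element-wise sum of the encoded vectors.
aggregate = [sum(col) for col in zip(*encoded)]
result = decode(aggregate, seed, num_elements=5)
print([round(v, 6) for v in result])  # [11.5, 0.5, 3.5, 11.5, 6.5]
```

If clients used different rotation angles, the encoded vectors would no longer share one linear map and their sum could not be decoded this way, which is why the sketch derives all angles from the same seed.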
The allowed input dtypes are integers and floats. However, the dtype passed to the inner_agg_factory will always be a float.
Methods
create
create(
    value_type: factory.ValueType
) -> tff.templates.AggregationProcess
Creates a tff.templates.AggregationProcess without weights.
The provided value_type is a non-federated tff.Type, that is, not a tff.FederatedType.
The returned tff.templates.AggregationProcess will be created for aggregation of values matching value_type placed at tff.CLIENTS.
That is, its next method will expect type <S@SERVER, {value_type}@CLIENTS>, where S is the unplaced return type of its initialize method.
| Args | |
|---|---|
| value_type | A non-federated tff.Type. |

| Returns |
|---|
| A tff.templates.AggregationProcess. |