Libraries for interacting with MapReduce-like backends.
This package contains libraries for using TFF in backend systems that offer
MapReduce-like capabilities, i.e., systems that can perform parallel processing
on a set of clients and then aggregate the results of such processing on the
server. Systems of this type do not support the full expressiveness of TFF, but
they are common enough in practice to warrant a dedicated set of libraries, and
many examples of TFF computations, including those constructed by
`tff.learning`, can be compiled by TFF into a form that can be deployed on such
systems.
This package defines three data structures: `BroadcastForm`, `MapReduceForm`,
and `DistributeAggregateForm`. `DistributeAggregateForm` will eventually
replace `BroadcastForm` and `MapReduceForm`.
The type signature of a TFF computation `round_comp` that can be converted into
`MapReduceForm` or `DistributeAggregateForm` is as follows:

```
(<S@SERVER,{D}@CLIENTS> -> <S@SERVER,X@SERVER>)
```
The server state is the first component of the input, and the computation returns the updated server state as the first component of the output. Since the set of clients involved in a federated computation will often vary from round to round, the server state is sometimes needed to connect subsequent rounds into a single contiguous logical sequence. If there is no need for server state, the input/output state should be modeled as an empty tuple. The computation can also take client-side data as input, and can produce results on the server side in addition to the state intended to be passed to the next round. As is the case for the server state, if either of these is not needed, it should be modeled as an empty tuple.
The above type signature involves the following abstract types:
*   `S` is the type of the state that is passed at the server between rounds of
    processing. For example, in the context of federated training, the server
    state would typically include the weights of the model being trained. The
    weights would be updated in each round as the model is trained on more and
    more of the clients' data, and hence the server state would evolve as well.

*   `D` represents the type of per-client units of data that serve as the input
    to the computation. Often, this would be a sequence type, i.e., a dataset
    in TensorFlow's parlance, although strictly speaking this does not always
    have to be the case.

*   `X` represents the type of server-side outputs generated by the server
    after each round.
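For concreteness, a hypothetical instantiation of this signature, in which `S`
is a single `float32` model parameter, `D` is a sequence (dataset) of `float32`
examples, and `X` is a `float32` round metric, would read:

```
(<float32@SERVER,{float32*}@CLIENTS> -> <float32@SERVER,float32@SERVER>)
```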
One can think of the process based on this representation as being equivalent
to the following pseudocode loop:

```python
client_data = ...
server_state = initialize_comp()
while True:
  server_state, server_outputs = round_comp(server_state, client_data)
```
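To make this concrete, here is a minimal sketch of a round computation with
this shape, assuming the classic `tff.tf_computation` /
`tff.federated_computation` decorators and the `tff.type_at_server` /
`tff.type_at_clients` helpers (names and availability may differ across TFF
versions). Here `S` is a running `float32` total, `D` is a dataset of `float32`
examples, and `X` is the round's total:

```python
import tensorflow as tf
import tensorflow_federated as tff


@tff.tf_computation(tff.SequenceType(tf.float32))
def local_sum(dataset):
  # Reduce one client's dataset to a single local sum.
  return dataset.reduce(0.0, lambda acc, x: acc + x)


@tff.federated_computation(
    tff.type_at_server(tf.float32),
    tff.type_at_clients(tff.SequenceType(tf.float32)))
def round_comp(server_state, client_data):
  # {D}@CLIENTS -> {U}@CLIENTS: local processing on each client.
  client_sums = tff.federated_map(local_sum, client_data)
  # {U}@CLIENTS -> R@SERVER: aggregate the per-client sums.
  round_total = tff.federated_sum(client_sums)
  # <S,R>@SERVER -> S@SERVER: fold the round's total into the state.
  new_state = tff.federated_map(
      tff.tf_computation(lambda s, r: s + r), (server_state, round_total))
  return new_state, round_total
```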
In `MapReduceForm`, the logic of `round_comp` is factored into seven main
components that are all TensorFlow functions: `prepare`, `work`, `zero`,
`accumulate`, `merge`, `report`, and `update`. There are also additional
`secure_sum_bitwidth`, `secure_sum_max_input`, and `secure_modular_sum_modulus`
TensorFlow function components that specify runtime parameters for the
`federated_secure_sum_*` intrinsics. The pseudocode below uses common
syntactic shortcuts (such as implicit zipping) when showing how an instance of
`MapReduceForm` maps to a single federated round.
```python
@tff.federated_computation
def round_comp(server_state, client_data):
  # The server prepares an input to be broadcast to all clients that controls
  # what will happen in this round.
  client_input = (
      tff.federated_broadcast(tff.federated_map(prepare, server_state)))

  # The clients all independently do local work and produce updates.
  client_updates = tff.federated_map(work, [client_data, client_input])

  # `client_updates` is a 4-tuple whose elements are passed to the following
  # intrinsics:
  # 1. `federated_aggregate`
  # 2. `federated_secure_sum_bitwidth`
  # 3. `federated_secure_sum`
  # 4. `federated_secure_modular_sum`
  # The intrinsics aggregate the updates across the system into a single
  # global update at the server.
  simple_agg = tff.federated_aggregate(
      client_updates[0], zero(), accumulate, merge, report)
  secure_aggs = [
      tff.federated_secure_sum_bitwidth(client_updates[1], bitwidth()),
      tff.federated_secure_sum(client_updates[2], max_input()),
      tff.backends.mapreduce.federated_secure_modular_sum(
          client_updates[3], modulus()),
  ]
  global_update = [simple_agg] + secure_aggs

  # Finally, the server produces a new state as well as server-side output to
  # emit from this round.
  new_server_state, server_output = (
      tff.federated_map(update, [server_state, global_update]))

  # The updated server state and the server-side output are returned as
  # results of this round.
  return new_server_state, server_output
```
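As a usage sketch, a compatible computation (such as the federated-sum
`round_comp` defined earlier) can be round-tripped through `MapReduceForm`.
This assumes `get_map_reduce_form_for_computation` accepts the round
computation directly and that the form exposes its components as attributes;
consult the function and class documentation for the exact contract:

```python
mrf = tff.backends.mapreduce.get_map_reduce_form_for_computation(round_comp)
# Inspect the factored TensorFlow components, e.g. the client-side work.
print(mrf.work.type_signature)
# Reconstruct an equivalent federated computation from the form.
round_comp_again = tff.backends.mapreduce.get_computation_for_map_reduce_form(
    mrf)
```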
Details on the seven main pieces of pure TensorFlow logic in `MapReduceForm`
are below. Please also consult the documentation for the related federated
operators for more detail (particularly `tff.federated_aggregate`, as several
of the components below correspond directly to the parameters of that
operator).
*   `prepare` represents the preparatory steps taken by the server to generate
    inputs that will be broadcast to the clients and that, together with the
    client data, will drive the client-side work in this round. It takes the
    initial state of the server, and produces the input for use by the clients.
    Its type signature is `(S -> C)`.

*   `work` represents the totality of client-side processing, again all as a
    single section of TensorFlow code. It takes a tuple of client data and
    client input that was broadcast by the server, and returns a tuple
    containing the client updates to be aggregated (across all the clients).
    The first index of this tuple will be passed to an aggregation
    parameterized by the blocks of TensorFlow below (`zero`, `accumulate`,
    `merge`, and `report`), and the remaining indices will be passed to the
    secure aggregation intrinsics shown in the pseudocode above. Its type
    signature is `(<D,C> -> <U,V>)`.

*   `zero` is the TensorFlow computation that produces the initial state of the
    accumulators that are used to combine updates collected from subsets of the
    client population. In some systems, all accumulation may happen at the
    server, but for scalability reasons, it is often desirable to structure
    aggregation in multiple tiers. Its type signature is `A`, or when
    represented as a `tff.Computation` in Python, `( -> A)`.

*   `accumulate` is the TensorFlow computation that updates the state of an
    update accumulator (initialized with `zero` above) with a single client's
    update. Its type signature is `(<A,U> -> A)`. Typically, a single
    accumulator would be used to combine the updates from multiple clients, but
    this does not have to be the case (it's up to the target deployment
    platform to choose how to use this logic in a particular deployment
    scenario).

*   `merge` is the TensorFlow computation that merges two accumulators holding
    the results of aggregation over two disjoint subsets of clients. Its type
    signature is `(<A,A> -> A)`.

*   `report` is the TensorFlow computation that transforms the state of the
    top-most accumulator (after accumulating updates from all clients and
    merging all the resulting accumulators into a single one at the top level
    of the system hierarchy) into the final result of aggregation. Its type
    signature is `(A -> R)`. A concrete sketch of `zero`, `accumulate`,
    `merge`, and `report` appears after the abstract types below.

*   `update` is the TensorFlow computation that applies the aggregate of all
    clients' updates (the output of `report`), also referred to above as the
    global update, to the server state to produce a new server state to feed
    into the next round, and that additionally outputs a server-side output to
    be reported externally as one of the results of this round. In federated
    learning scenarios, the server-side outputs might include things like loss
    and accuracy metrics, and the server state to be carried over, as noted
    above, may include the model weights to be trained further in a subsequent
    round. The type signature of this computation is `(<S,R> -> <S,X>)`.
The type signatures of the above TensorFlow computations involve the following
abstract types in addition to those defined earlier:

*   `C` is the type of the inputs for the clients, to be supplied by the server
    at the beginning of each round (or an empty tuple if not needed).

*   `U` is the type of the per-client update to be produced in each round and
    fed into the cross-client federated aggregation protocol.

*   `V` is the type of the per-client update to be produced in each round and
    fed into the cross-client secure aggregation protocol.

*   `A` is the type of the accumulators used to combine updates from subsets of
    clients.

*   `R` is the type of the final result of aggregating all client updates, the
    global update to be incorporated into the server state at the end of the
    round (and used to produce the server-side output).
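For illustration, here is a hypothetical set of the four aggregation components
computing a federated mean, with `U = float32` and the accumulator type `A` a
named `<total, count>` pair. This is a sketch only; `MapReduceForm` instances
are normally produced by the compiler rather than written by hand:

```python
import collections

import tensorflow as tf
import tensorflow_federated as tff

ACC_SPEC = collections.OrderedDict(total=tf.float32, count=tf.float32)


@tff.tf_computation
def zero():
  # ( -> A): the empty accumulator.
  return collections.OrderedDict(total=0.0, count=0.0)


@tff.tf_computation(ACC_SPEC, tf.float32)
def accumulate(acc, update):
  # (<A,U> -> A): fold one client's update into the accumulator.
  return collections.OrderedDict(
      total=acc['total'] + update, count=acc['count'] + 1.0)


@tff.tf_computation(ACC_SPEC, ACC_SPEC)
def merge(a, b):
  # (<A,A> -> A): combine partial aggregates from disjoint client subsets.
  return collections.OrderedDict(
      total=a['total'] + b['total'], count=a['count'] + b['count'])


@tff.tf_computation(ACC_SPEC)
def report(acc):
  # (A -> R): the final result is the mean of all client updates.
  return acc['total'] / acc['count']
```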
In `DistributeAggregateForm`, the logic of `round_comp` is factored into five
main components that are all TFF Lambda Computations (as defined in
`computation.proto`): `server_prepare`, `server_to_client_broadcast`,
`client_work`, `client_to_server_aggregation`, and `server_result`. The
pseudocode below shows how an instance of `DistributeAggregateForm` maps to a
single federated round.
```python
@tff.federated_computation
def round_comp(server_state, client_data):
  # The server prepares an input to be broadcast to all clients and generates
  # a temporary state that may be used by later parts of the computation.
  context_at_server, post_client_work_state = server_prepare(server_state)

  # Broadcast `context_at_server` to the clients.
  context_at_clients = server_to_client_broadcast(context_at_server)

  # The clients all independently do local work and produce updates.
  work_at_clients = client_work(client_data, context_at_clients)

  # Aggregate the client updates.
  intermediate_result_at_server = client_to_server_aggregation(
      post_client_work_state, work_at_clients)

  # Finally, the server produces a new state as well as server-side output to
  # emit from this round.
  new_server_state, server_output = server_result(
      post_client_work_state, intermediate_result_at_server)

  # The updated server state and server-side output are returned as results of
  # this round.
  return new_server_state, server_output
```
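As with `MapReduceForm`, a compatible computation can be round-tripped through
this representation. The sketch below assumes
`get_distribute_aggregate_form_for_computation` accepts the round computation
directly; consult the function documentation for the exact contract:

```python
daf = tff.backends.mapreduce.get_distribute_aggregate_form_for_computation(
    round_comp)
round_comp_again = (
    tff.backends.mapreduce.get_computation_for_distribute_aggregate_form(daf))
assert str(round_comp.type_signature) == str(round_comp_again.type_signature)
```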
Details on the five components of `DistributeAggregateForm` are below, followed
by a short sketch of inspecting these components on a concrete form.
*   `server_prepare` represents the preparatory steps taken by the server to
    generate 1) inputs that will be broadcast to the clients and 2) a temporary
    state that may be needed by the `client_to_server_aggregation` and
    `server_result` components. The entire lambda may contain only SERVER
    placements, and its type signature is `(S -> <B_I, T>)`.

*   `server_to_client_broadcast` represents the broadcast of data from the
    server to the clients. It contains a block of locals that are exclusively
    intrinsics with `IntrinsicDef.broadcast_kind` and that depend only on the
    `server_to_client_broadcast` args. It returns the results of these
    intrinsics in the order they are computed. Its type signature is
    `(B_I -> B_O)`.

*   `client_work` represents the totality of client-side processing. It takes a
    tuple of client data and client input that was broadcast by the server, and
    returns the client update to be aggregated (across all the clients). The
    entire lambda may contain only CLIENTS placements, and its type signature
    is `(<D, B_O> -> A_I)`.

*   `client_to_server_aggregation` represents the aggregation of data from the
    clients to the server. It may incorporate the temporary state that was
    generated by the `server_prepare` component to set dynamic aggregation
    parameters. It contains a block of locals that are exclusively intrinsics
    with `IntrinsicDef.aggregation_kind` and that depend only on the
    `client_to_server_aggregation` args. It returns the results of these
    intrinsics in the order they are computed. Its type signature is
    `(<T, A_I> -> A_O)`.

*   `server_result` represents the post-processing steps taken by the server.
    It may depend on the temporary state that was generated by the
    `server_prepare` component, and on the data that was aggregated from the
    clients. It generates a new server state to feed into the next round and an
    additional server-side output to be reported externally as one of the
    results of this round. In federated learning scenarios, the server-side
    outputs might include things like loss and accuracy metrics, and the server
    state to be carried over may include the model weights. The entire lambda
    may contain only SERVER placements, and its type signature is
    `(<T, A_O> -> <S, X>)`.
The type signatures of the above TFF Lambda Computations involve the following
abstract types in addition to those defined earlier:

*   `B_I` is the type of the broadcast inputs.
*   `B_O` is the type of the broadcast outputs.
*   `A_I` is the type of the aggregation inputs.
*   `A_O` is the type of the aggregation outputs.
*   `T` is the type of the temporary state.
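Assuming a `DistributeAggregateForm` instance exposes the five components as
same-named attributes (an assumption; check the class documentation), the
factored structure of the `daf` built above can be inspected directly:

```python
for name in ('server_prepare', 'server_to_client_broadcast', 'client_work',
             'client_to_server_aggregation', 'server_result'):
  # Print each component's TFF type signature, e.g. (S -> <B_I, T>).
  print(name, ':', getattr(daf, name).type_signature)
```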
## Classes

`class BroadcastForm`: Standardized representation of server-to-client logic.

`class DistributeAggregateForm`: Standard representation of logic deployable to
a federated learning system.

`class MapReduceForm`: Standardized representation of logic deployable to
MapReduce-like systems.
## Functions

`check_computation_compatible_with_map_reduce_form(...)`: Tests compatibility
with `tff.backends.mapreduce.MapReduceForm`.

`consolidate_and_extract_local_processing(...)`: Consolidates all the local
processing in `comp`.

`federated_secure_modular_sum(...)`: Computes a modular sum at `tff.SERVER` of
a `value` from `tff.CLIENTS`.

`get_broadcast_form_for_computation(...)`: Constructs
`tff.backends.mapreduce.BroadcastForm` given a computation.

`get_computation_for_broadcast_form(...)`: Creates `tff.Computation` from a
broadcast form.

`get_computation_for_distribute_aggregate_form(...)`: Creates
`tff.Computation` from a DistributeAggregate form.

`get_computation_for_map_reduce_form(...)`: Creates `tff.Computation` from a
MapReduce form.

`get_distribute_aggregate_form_for_computation(...)`: Constructs
`DistributeAggregateForm` for a computation.

`get_map_reduce_form_for_computation(...)`: Constructs
`tff.backends.mapreduce.MapReduceForm` for a computation.

`get_state_initialization_computation(...)`: Validates and transforms a
computation to generate state.

`parse_tff_to_tf(...)`: Parses a TFF construct `comp` into a TensorFlow
construct.
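As one usage example among these, `federated_secure_modular_sum` can be invoked
like the other federated aggregation operators, following the
`(value, modulus)` argument order shown in the `MapReduceForm` pseudocode above
(a sketch; the modulus here is an illustrative constant):

```python
import tensorflow as tf
import tensorflow_federated as tff


@tff.federated_computation(tff.type_at_clients(tf.int32))
def modular_total(values):
  # Securely sum the client values at the server, modulo 8.
  return tff.backends.mapreduce.federated_secure_modular_sum(values, 8)
```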