
tff.experimental.learning.build_jax_federated_averaging_process

Constructs an iterative process that implements simple federated averaging.

Args
batch_type: An instance of tff.Type that represents the type of a single batch of data to use for training. This type should be constructed with standard Python containers (such as collections.OrderedDict) of the kind that loss_fn expects as arguments.
model_type: An instance of tff.Type that represents the type of the model. As with batch_type, this type should be constructed with standard Python containers (such as collections.OrderedDict) of the kind that loss_fn expects as arguments.
loss_fn: A loss function for the model. It must be a Python function that takes two parameters: the model (of type model_type) and a single batch of data (of type batch_type).
step_size: The step size to use during training (an np.float32).
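To illustrate what loss_fn receives, the sketch below assumes a linear softmax classifier. The model is an OrderedDict with hypothetical keys 'weights' and 'bias', and the batch is an OrderedDict with hypothetical keys 'pixels' and 'labels'. The shapes are illustrative assumptions. The matching model_type and batch_type would presumably use the same OrderedDict layout with tff.TensorType leaves. The sketch only runs the loss and its gradient in plain JAX and does not call TFF.

```python
import collections

import jax
import jax.numpy as jnp
import numpy as np

NUM_FEATURES = 784  # assumed input size (e.g. flattened 28x28 images)
NUM_CLASSES = 10    # assumed number of output classes
BATCH_SIZE = 50     # assumed batch size


def loss_fn(model, batch):
  """Mean softmax cross-entropy of a linear classifier on one batch.

  `model` and `batch` are OrderedDicts whose structure would mirror
  model_type and batch_type; the key names here are hypothetical.
  """
  logits = jnp.matmul(batch['pixels'], model['weights']) + model['bias']
  log_probs = jax.nn.log_softmax(logits)
  targets = jax.nn.one_hot(batch['labels'], NUM_CLASSES)
  return -jnp.mean(jnp.sum(targets * log_probs, axis=1))


# A zero-initialized model and a random batch with the assumed shapes.
model = collections.OrderedDict(
    weights=np.zeros((NUM_FEATURES, NUM_CLASSES), dtype=np.float32),
    bias=np.zeros((NUM_CLASSES,), dtype=np.float32),
)
batch = collections.OrderedDict(
    pixels=np.random.rand(BATCH_SIZE, NUM_FEATURES).astype(np.float32),
    labels=np.random.randint(0, NUM_CLASSES, size=(BATCH_SIZE,)).astype(np.int32),
)

loss = loss_fn(model, batch)
# Gradients with respect to the model keep the same OrderedDict structure.
grads = jax.grad(loss_fn)(model, batch)
```

With a zero-initialized model every class receives equal probability, so the loss equals log(10). Because JAX treats OrderedDicts as pytrees, the gradient has the same container layout as the model.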

Returns
An instance of tff.templates.IterativeProcess that implements federated training in JAX.
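The following pure-JAX sketch shows the kind of computation that one round of simple federated averaging performs. It rests on two assumptions. First, each client takes one gradient step of size step_size per batch on its own copy of the model. Second, the server forms an unweighted mean of the resulting client models. This is not TFF's implementation, and TFF's client update and aggregation details may differ. The helpers local_train and federated_averaging_round, and the toy linear-regression loss, are hypothetical.

```python
import collections

import jax
import jax.numpy as jnp


def local_train(model, batches, loss_fn, step_size):
  """Runs one SGD step per batch on a single client's copy of the model."""
  grad_fn = jax.grad(loss_fn)
  for batch in batches:
    grads = grad_fn(model, batch)
    model = jax.tree_util.tree_map(lambda p, g: p - step_size * g, model, grads)
  return model


def federated_averaging_round(model, client_batches, loss_fn, step_size):
  """Trains on every client, then averages the client models (unweighted)."""
  client_models = [
      local_train(model, batches, loss_fn, step_size) for batches in client_batches
  ]
  return jax.tree_util.tree_map(
      lambda *params: jnp.mean(jnp.stack(params), axis=0), *client_models)


def squared_error(model, batch):
  """Hypothetical toy loss: mean squared error of y = w * x + b."""
  pred = batch['x'] * model['w'] + model['b']
  return jnp.mean((pred - batch['y']) ** 2)


initial_model = collections.OrderedDict(w=jnp.array(0.0), b=jnp.array(0.0))
# Two clients, each holding a single batch.
client_batches = [
    [collections.OrderedDict(x=jnp.array([1.0]), y=jnp.array([2.0]))],
    [collections.OrderedDict(x=jnp.array([2.0]), y=jnp.array([4.0]))],
]
new_model = federated_averaging_round(
    initial_model, client_batches, squared_error, step_size=0.1)
```

Here the first client moves to w = 0.4, b = 0.4 and the second to w = 1.6, b = 0.8, so the averaged model is w = 1.0, b = 0.6. Repeating the round with the returned model gives the multi-round loop that an iterative process drives.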