Module: tff.learning.framework

The public API for contributors who develop federated learning algorithms.

Classes

class ClientDeltaFn: Represents a client computation that produces an update to a model.

class ClientOutput: Structure for outputs returned from clients during federated optimization.

class EnhancedModel: A wrapper around a Model that adds sanity checking and metadata helpers.

class ModelWeights: A container for the trainable and non-trainable variables of a Model.

class ServerState: Represents the state of the server carried between rounds.
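The sketch below illustrates how a client computation turns the model weights it receives into a ClientOutput carrying a model update. It is a minimal plain-Python sketch under stated assumptions, not the TFF implementation: the dataclasses are hypothetical stand-ins for ModelWeights and ClientOutput, the field names are illustrative, and the function client_delta_fn, the toy 1-D linear model y = w * x, and its SGD loop are all assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


# Hypothetical stand-in for ModelWeights (NOT the TFF class).
@dataclass
class ModelWeights:
    trainable: List[float]      # parameters updated by local training
    non_trainable: List[float]  # e.g. counters or statistics, not trained


# Hypothetical stand-in for ClientOutput (NOT the TFF class).
@dataclass
class ClientOutput:
    weights_delta: List[float]   # locally trained weights minus received weights
    weights_delta_weight: float  # aggregation weight, here the example count


def client_delta_fn(initial: ModelWeights,
                    dataset: Sequence[Tuple[float, float]],
                    learning_rate: float = 0.1) -> ClientOutput:
    """Hypothetical client update: one SGD pass on the toy model y = w * x."""
    w = initial.trainable[0]
    for x, y in dataset:
        # Gradient of the squared error (w * x - y) ** 2 with respect to w.
        grad = 2.0 * (w * x - y) * x
        w -= learning_rate * grad
    # Report the change in weights, not the weights themselves.
    return ClientOutput(weights_delta=[w - initial.trainable[0]],
                        weights_delta_weight=float(len(dataset)))


start = ModelWeights(trainable=[0.0], non_trainable=[])
out = client_delta_fn(start, [(1.0, 2.0)])
print(out.weights_delta, out.weights_delta_weight)
```

Returning a delta plus a weight, rather than raw weights, is what lets a server combine updates from clients with different amounts of data through a weighted mean.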

Functions

build_encoded_broadcast_from_model(...): Builds a StatefulBroadcastFn that broadcasts encoded weights of the model returned by model_fn.

build_encoded_broadcast_process_from_model(...): Builds a MeasuredProcess that broadcasts encoded weights of the model returned by model_fn.

build_encoded_mean_from_model(...): Builds a StatefulAggregateFn that computes an encoded mean of the weights of the model returned by model_fn.

build_encoded_mean_process_from_model(...): Builds a MeasuredProcess that computes an encoded mean of the weights of the model returned by model_fn.

build_encoded_sum_from_model(...): Builds a StatefulAggregateFn that computes an encoded sum of the weights of the model returned by model_fn.

build_encoded_sum_process_from_model(...): Builds a MeasuredProcess that computes an encoded sum of the weights of the model returned by model_fn.

build_model_delta_optimizer_process(...): Constructs a tff.templates.IterativeProcess for Federated Averaging or SGD.

build_stateless_broadcaster(...): Builds a MeasuredProcess that wraps tff.federated_broadcast.

build_stateless_mean(...): Builds a MeasuredProcess that wraps tff.federated_mean.

enhance(...): Wraps a tff.learning.Model as an EnhancedModel.

weights_type_from_model(...): Creates a tff.Type from a tff.learning.Model or callable that constructs a model.
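As a rough intuition for what a Federated Averaging round built from these pieces does (broadcast the server weights, compute client deltas, take a weighted mean of the deltas, update the server), here is a minimal plain-Python sketch. It does not use TFF: the names local_sgd, weighted_mean, and run_round, the scalar toy model y = w * x, and the plain additive server step are all hypothetical. In TFF the server step is configured through a server optimizer; this sketch hard-codes the equivalent of server SGD with learning rate server_lr applied to the negated mean delta.

```python
from typing import List, Sequence, Tuple

Dataset = Sequence[Tuple[float, float]]


def local_sgd(w: float, data: Dataset, lr: float) -> float:
    """Hypothetical client training: SGD on the squared error of y = w * x."""
    for x, y in data:
        w -= lr * 2.0 * (w * x - y) * x
    return w


def weighted_mean(values: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted mean, the kind of reduction a weighted federated mean computes."""
    return sum(v * wt for v, wt in zip(values, weights)) / sum(weights)


def run_round(server_w: float, client_datasets: List[Dataset],
              lr: float = 0.1, server_lr: float = 1.0) -> float:
    deltas, delta_weights = [], []
    for data in client_datasets:
        # "Broadcast": every client starts from the same server weights.
        local_w = local_sgd(server_w, data, lr)
        deltas.append(local_w - server_w)
        delta_weights.append(float(len(data)))  # weight by example count
    # "Aggregate": weighted mean of the client deltas.
    mean_delta = weighted_mean(deltas, delta_weights)
    # "Server update": apply the averaged delta.
    return server_w + server_lr * mean_delta


clients = [[(1.0, 2.0)], [(1.0, 2.0), (1.0, 2.0)]]
w = 0.0
for _ in range(50):
    w = run_round(w, clients)
print(w)  # approaches 2.0, the slope shared by all client data
```

Weighting each delta by its client's example count means a client with more data moves the server model proportionally more, which is the usual Federated Averaging choice.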