Client TensorFlow logic for Federated Averaging.
Inherits From: ClientDeltaFn

    tff.learning.ClientFedAvg(
        model: tff.learning.Model,
        optimizer: tf.keras.optimizers.Optimizer,
        client_weighting: Union[ClientWeighting, ClientWeightFnType] = tff.learning.ClientWeighting.NUM_EXAMPLES,
        use_experimental_simulation_loop: bool = False
    )


Args | |
---|---|
`model` | A `tff.learning.Model` instance. |
`optimizer` | A `tf.keras.optimizers.Optimizer` instance. |
`client_weighting` | A value of `tff.learning.ClientWeighting` that specifies a built-in weighting method, or a callable that takes the output of `model.report_local_outputs` and returns a tensor that provides the weight in the federated average of model deltas. |
`use_experimental_simulation_loop` | Controls the reduce loop function used to iterate over the input dataset. If `True`, an experimental reduce loop intended for simulation is used. |
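
The `client_weighting` argument only decides the scalar weight each client contributes; the federated average of model deltas is then the sum of `w_k * delta_k` over clients divided by the sum of the `w_k`. The pure-Python sketch below illustrates that aggregation for the two built-in choices (`NUM_EXAMPLES` and `UNIFORM`) and for a custom callable. It does not use TFF: the strings merely mirror the member names of `tff.learning.ClientWeighting`, and the dictionary keys `num_examples` and `loss` standing in for the output of `model.report_local_outputs` are hypothetical, since the real keys depend on the model.

```python
from typing import Callable, Dict, List, Sequence, Union

# Hypothetical stand-in for the output of model.report_local_outputs();
# the real keys depend on the tff.learning.Model implementation.
LocalOutputs = Dict[str, float]
WeightFn = Callable[[LocalOutputs], float]


def client_weight(outputs: LocalOutputs, weighting: Union[str, WeightFn]) -> float:
    """Returns one client's weight in the federated average."""
    if callable(weighting):
        return float(weighting(outputs))  # custom weight function
    if weighting == "NUM_EXAMPLES":
        return float(outputs["num_examples"])
    if weighting == "UNIFORM":
        return 1.0
    raise ValueError(f"unknown weighting: {weighting!r}")


def federated_average(deltas: Sequence[List[float]], weights: Sequence[float]) -> List[float]:
    """Weighted average of per-client model deltas (each a flat list of floats)."""
    total = sum(weights)
    return [
        sum(w * d[i] for w, d in zip(weights, deltas)) / total
        for i in range(len(deltas[0]))
    ]


# Two hypothetical clients holding different amounts of local data.
deltas = [[1.0, 0.0], [0.0, 1.0]]
outputs = [{"num_examples": 30, "loss": 0.5}, {"num_examples": 10, "loss": 2.0}]

# The lambda (inverse loss) is an arbitrary, hypothetical custom weighting.
for weighting in ["NUM_EXAMPLES", "UNIFORM", lambda o: 1.0 / o["loss"]]:
    weights = [client_weight(o, weighting) for o in outputs]
    print(federated_average(deltas, weights))
```

Weighting by example count gives the larger client three quarters of the averaged update (`[0.75, 0.25]`), whereas uniform weighting splits it evenly (`[0.5, 0.5]`).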

Attributes | |
---|---|
`variables` | Returns all the variables of this object. Note this only includes variables that are part of the state of this object, and not the model variables themselves. |
Methods

`__call__`

    @tf.function
    __call__(
        dataset, initial_weights
    )
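
`__call__` receives a client's `dataset` and the `initial_weights` sent by the server and runs the client side of Federated Averaging: start from the initial weights, take a local optimizer step on every batch, and report the difference between the trained and initial weights together with the client's weight. The sketch below shows that flow for a one-parameter linear model trained with plain SGD in pure Python. It is not the TFF implementation: the model, `learning_rate`, the `(x, y)` batch format, and the names `sgd_step` and `client_update` are hypothetical, and the `use_simulation_loop` flag only mimics the idea behind `use_experimental_simulation_loop` (two interchangeable ways of folding a training step over the dataset).

```python
import functools
from typing import Iterable, List, Tuple

Batch = List[Tuple[float, float]]  # hypothetical batch of (x, y) pairs


def sgd_step(w: float, batch: Batch, learning_rate: float) -> float:
    """One SGD step on mean squared error for the model y_hat = w * x."""
    grad = sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)
    return w - learning_rate * grad


def client_update(
    dataset: Iterable[Batch],
    initial_weight: float,
    learning_rate: float = 0.1,
    use_simulation_loop: bool = False,
) -> Tuple[float, int]:
    """Returns (weight delta, number of examples seen) for one client."""

    def train_step(state: Tuple[float, int], batch: Batch) -> Tuple[float, int]:
        # State carries the current weight and a running example count.
        w, n = state
        return sgd_step(w, batch, learning_rate), n + len(batch)

    state = (initial_weight, 0)
    if use_simulation_loop:
        # Explicit Python loop over the batches.
        for batch in iter(dataset):
            state = train_step(state, batch)
    else:
        # Fold the step over the dataset, analogous to a dataset reduce.
        state = functools.reduce(train_step, dataset, state)

    trained_weight, num_examples = state
    # The client reports trained-minus-initial weights, not the weights.
    return trained_weight - initial_weight, num_examples


# Hypothetical client data drawn from y = 2x, split into two batches.
dataset = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]
delta, num_examples = client_update(dataset, initial_weight=0.0)
print(delta, num_examples)
```

Both loop modes apply the same steps in the same order and therefore return the same delta; only the iteration mechanism differs. With `NUM_EXAMPLES` weighting, the returned example count would serve as this client's weight in the federated average.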