Builds a finalizer that applies a step of an optimizer.
tff.learning.templates.build_apply_optimizer_finalizer(
    optimizer_fn: tff.learning.optimizers.Optimizer,
    model_weights_type: tff.types.StructType,
    should_reject_update: Callable[[Any, Any], tuple[Union[bool, tf.Tensor], Optional[_MeasurementsType]]] = tff.learning.templates.reject_non_finite_update
)
The provided model_weights_type must be a non-federated tff.Type with the tff.learning.models.ModelWeights container.
The second input argument of the created FinalizerProcess.next (the first is the process state) expects a value matching model_weights_type, and its third argument expects a value matching model_weights_type.trainable. The optimizer is applied to the trainable model weights only, leaving the non_trainable weights unmodified.
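The split between trainable and non-trainable weights can be illustrated without TFF. The following is a minimal, framework-free sketch, assuming plain SGD with a fixed learning rate, weights stored as flat lists of floats, and the update treated like a gradient. `ModelWeightsSketch`, `sgd_step`, and `apply_finalizer_step` are hypothetical names, not TFF APIs; the sketch only mirrors the documented behavior: the update matches the trainable part, and the non-trainable part passes through unchanged.

```python
from dataclasses import dataclass, field


@dataclass
class ModelWeightsSketch:
    """Hypothetical stand-in for tff.learning.models.ModelWeights."""
    trainable: list[float]
    non_trainable: list[float] = field(default_factory=list)


def sgd_step(weights: list[float], gradient: list[float], lr: float) -> list[float]:
    """Plain SGD, applied elementwise: w <- w - lr * g."""
    return [w - lr * g for w, g in zip(weights, gradient)]


def apply_finalizer_step(
    model_weights: ModelWeightsSketch, update: list[float], lr: float = 1.0
) -> ModelWeightsSketch:
    # The update must match the trainable part of the weights only.
    if len(update) != len(model_weights.trainable):
        raise ValueError("update must match model_weights.trainable")
    new_trainable = sgd_step(model_weights.trainable, update, lr)
    # Non-trainable weights are copied through unmodified.
    return ModelWeightsSketch(new_trainable, list(model_weights.non_trainable))


weights = ModelWeightsSketch(trainable=[1.0, 2.0], non_trainable=[5.0])
new_weights = apply_finalizer_step(weights, update=[0.5, -0.5], lr=0.1)
print(new_weights.trainable, new_weights.non_trainable)
```

Returning a new container instead of mutating the input keeps the previous round's weights available, which matters when an update can be rejected.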
The state of the process is the state of the optimizer, and the process returns empty measurements.
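Because the process state is simply the optimizer state, a stateful optimizer carries its internal buffers from round to round through that state. The sketch below illustrates this without TFF, assuming heavy-ball SGD with momentum whose state is a momentum buffer; `FinalizerOutput`, `initialize`, and `next_step` are hypothetical names chosen to echo the (state, result, measurements) shape of a process output, not TFF APIs.

```python
from collections import OrderedDict
from typing import NamedTuple


class FinalizerOutput(NamedTuple):
    """Hypothetical analogue of a (state, result, measurements) triple."""
    state: list[float]
    result: list[float]
    measurements: OrderedDict


def initialize(num_params: int) -> list[float]:
    # The process state is just the optimizer state: a zero momentum buffer.
    return [0.0] * num_params


def next_step(state, trainable, update, lr=0.1, momentum=0.9) -> FinalizerOutput:
    # Heavy-ball momentum: m <- momentum * m + g, then w <- w - lr * m.
    new_state = [momentum * m + g for m, g in zip(state, update)]
    new_trainable = [w - lr * m for w, m in zip(trainable, new_state)]
    # Matching the documented behavior, no measurements are reported.
    return FinalizerOutput(new_state, new_trainable, OrderedDict())


state = initialize(2)
weights = [1.0, 1.0]
for update in ([1.0, 0.0], [1.0, 0.0]):
    state, weights, measurements = next_step(state, weights, update)
print(state, weights, measurements)
```

The second round moves the first weight further than the first round did, even though the update is identical, because the momentum buffer stored in the state has grown.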
Args
| Argument | Description |
|---|---|
| optimizer_fn | A tff.learning.optimizers.Optimizer. This optimizer is used to apply client updates to the server model. |
| model_weights_type | A non-federated tff.Type of the model weights to be optimized, which must have a tff.learning.models.ModelWeights container. |
| should_reject_update | A callable that takes the optimizer state and the model weights update, and returns a pair: a boolean (or bool tensor) indicating whether the update should be rejected, and an optional OrderedDict of measurements. If the update is rejected, the process falls back to the previous round's optimizer state and model weights; otherwise the check has no effect. The default, reject_non_finite_update, checks whether the model update contains any non-finite value and returns the result. |
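The rejection contract can be illustrated with a framework-free sketch, assuming a stateless plain-SGD optimizer and flat lists of floats. `reject_non_finite_sketch` and `guarded_step` are hypothetical names, and the measurement key `num_non_finite` is an assumption for illustration; this is not the actual reject_non_finite_update implementation. It shows the two halves of the contract: the callable returns a (reject, measurements) pair, and a rejected update leaves the previous state and weights in place.

```python
import math
from collections import OrderedDict


def reject_non_finite_sketch(state, update):
    """Hypothetical analogue of reject_non_finite_update.

    Returns (should_reject, measurements); rejects if any value is NaN or inf.
    """
    del state  # The optimizer state is not needed for this particular check.
    num_non_finite = sum(1 for u in update if not math.isfinite(u))
    return num_non_finite > 0, OrderedDict(num_non_finite=num_non_finite)


def guarded_step(state, weights, update,
                 should_reject_update=reject_non_finite_sketch, lr=0.1):
    reject, measurements = should_reject_update(state, update)
    if reject:
        # Fall back to the previous round's optimizer state and weights.
        return state, weights, measurements
    # Plain SGD; this optimizer is stateless, so its state passes through.
    new_weights = [w - lr * u for w, u in zip(weights, update)]
    return state, new_weights, measurements


state, weights = (), [1.0, 2.0]
print(guarded_step(state, weights, [1.0, float("nan")]))  # rejected
print(guarded_step(state, weights, [1.0, 1.0]))           # applied
```

Passing the rejection rule as a parameter, as the signature above does, lets a caller swap in a stricter check (for example, a norm threshold) without changing the update logic.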
Returns
A FinalizerProcess that applies the optimizer.
Raises
| Exception | Condition |
|---|---|
| TypeError | If model_weights_type does not have a tff.learning.models.ModelWeights Python container, or contains a tff.types.FederatedType. |