A wrapper around a Model that adds sanity checking and metadata helpers.
Inherits From: Model
tff.learning.framework.EnhancedModel(model)
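The wrapper idea can be sketched in plain Python. This is not the TFF implementation: `ToyModel`, `ToyBatchOutput`, `CheckedModel`, and the specific checks are hypothetical. They only illustrate a wrapper that delegates to an inner model and validates what it returns.

```python
from collections import namedtuple

# Hypothetical stand-in for a per-batch output record (not the TFF class).
ToyBatchOutput = namedtuple("ToyBatchOutput", ["loss", "predictions", "num_examples"])


class ToyModel:
    """Hypothetical inner model: predicts w * x for a scalar weight w."""

    def __init__(self, w):
        self.w = w

    def forward_pass(self, batch_input, training=True):
        xs, ys = batch_input
        preds = [self.w * x for x in xs]
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
        return ToyBatchOutput(loss=loss, predictions=preds, num_examples=len(xs))


class CheckedModel:
    """Hypothetical wrapper: delegates to an inner model and sanity-checks its results."""

    def __init__(self, model):
        self._model = model

    def forward_pass(self, batch_input, training=True):
        out = self._model.forward_pass(batch_input, training=training)
        if not isinstance(out, ToyBatchOutput):
            raise TypeError(
                f"forward_pass must return ToyBatchOutput, got {type(out).__name__}"
            )
        if training and out.loss is None:
            raise ValueError("forward_pass must return a loss when training=True")
        return out


wrapped = CheckedModel(ToyModel(w=2.0))
print(wrapped.forward_pass(([1.0, 2.0], [2.0, 5.0])).loss)  # 0.5
```

Because the wrapper exposes the same method names as the inner model, callers can use it as a drop-in replacement while getting early, descriptive errors.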
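Attributes | |
---|---|
federated_output_computation | Performs federated aggregation of the Model's local_outputs. This is typically used to aggregate metrics across many clients; e.g., the body of the computation might sum the per-client example counts with tff.federated_sum and average the per-client loss with tff.federated_mean. N.B. It is assumed that all TensorFlow computation happens in the report_local_outputs method, and that this computation only uses TFF constructs to specify aggregations across clients.
input_spec | The type specification of the batch_input parameter for forward_pass: a nested structure of tf.TensorSpec objects that batch_input must match. Similar in spirit to tf.keras.Model.input_spec.
local_variables | An iterable of tf.Variable objects; see the Model class documentation for details.
non_trainable_variables | An iterable of tf.Variable objects; see the Model class documentation for details.
trainable_variables | An iterable of tf.Variable objects; see the Model class documentation for details.
weights | Returns a tff.learning.ModelWeights holding the model's trainable and non-trainable variables.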
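To make the aggregation step concrete without a TFF runtime, the sketch below imitates in plain Python what a federated_output_computation of this kind does to per-client outputs. The helpers `federated_sum_sim` and `federated_mean_sim` are hypothetical stand-ins for tff.federated_sum and an unweighted tff.federated_mean, and the client values are made up.

```python
from typing import Dict, List


def federated_sum_sim(values: List[float]) -> float:
    """Hypothetical stand-in for tff.federated_sum: add one value per client."""
    return sum(values)


def federated_mean_sim(values: List[float]) -> float:
    """Hypothetical stand-in for an unweighted tff.federated_mean."""
    return sum(values) / len(values)


def aggregate_outputs(client_outputs: List[Dict[str, float]]) -> Dict[str, float]:
    # Each client reports its own local outputs; the server combines them.
    return {
        "num_examples": federated_sum_sim([o["num_examples"] for o in client_outputs]),
        "loss": federated_mean_sim([o["loss"] for o in client_outputs]),
    }


clients = [
    {"num_examples": 10, "loss": 0.5},
    {"num_examples": 30, "loss": 1.5},
]
print(aggregate_outputs(clients))  # {'num_examples': 40, 'loss': 1.0}
```

An unweighted mean gives every client equal say regardless of how many examples it holds; tff.federated_mean also accepts an optional weight argument when a weighted average is wanted.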
Methods
forward_pass
forward_pass(batch_input, training=True)
Runs the forward pass and returns results.
This method should not modify any variables that are part of the model parameters, that is, variables that influence the predictions. Rather, this is done by the training loop.
However, this method may update aggregated metrics computed across calls to forward_pass; the final values of such metrics can be accessed via report_local_outputs.
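A toy, framework-free illustration of this contract: forward_pass reads the model parameter but never writes it, and only updates metric accumulators that a report_local_outputs-style method later reads. `ToyMetricsModel` and its attribute names are hypothetical; a real TFF model would hold these values in tf.Variable objects.

```python
class ToyMetricsModel:
    """Hypothetical toy model: prediction = w * x, squared-error loss."""

    def __init__(self, w):
        self.w = w              # model parameter: read, never written, by forward_pass
        self.loss_sum = 0.0     # metric accumulators, analogous to local_variables
        self.num_examples = 0

    def forward_pass(self, batch_input, training=True):
        xs, ys = batch_input
        preds = [self.w * x for x in xs]
        batch_loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
        # Update only the metric accumulators, weighting the loss by batch size.
        self.loss_sum += batch_loss * len(xs)
        self.num_examples += len(xs)
        return {"loss": batch_loss, "predictions": preds, "num_examples": len(xs)}

    def report_local_outputs(self):
        # Values aggregated over every forward_pass call so far.
        return {"loss": self.loss_sum / self.num_examples, "num_examples": self.num_examples}


m = ToyMetricsModel(w=1.0)
m.forward_pass(([1.0, 2.0], [1.0, 2.0]))  # perfect fit, batch loss 0.0
m.forward_pass(([3.0], [5.0]))            # error of 2, batch loss 4.0
print(m.w, m.report_local_outputs())
```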
Uses in TFF:
- To implement model evaluation.
- To implement federated gradient descent and other non-Federated-Averaging algorithms, where we want the model to run the forward pass and update metrics, but there is no optimizer (we might only compute gradients on the returned loss).
- To implement Federated Averaging.
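The last two uses can be sketched for a one-parameter model: the loss returned by a forward pass drives gradient steps in a training loop, and a Federated Averaging round averages the resulting client weights. Everything here is hypothetical toy code, and a central finite difference stands in for automatic differentiation such as tf.GradientTape.

```python
from typing import List, Tuple

Batch = Tuple[List[float], List[float]]  # (inputs, labels)


def forward_pass(w: float, batch: Batch) -> float:
    """Hypothetical model y = w * x: returns the batch mean squared error, changes nothing."""
    xs, ys = batch
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def loss_grad(w: float, batch: Batch, h: float = 1e-4) -> float:
    # Central finite difference of the returned loss, standing in for
    # automatic differentiation such as tf.GradientTape.
    return (forward_pass(w + h, batch) - forward_pass(w - h, batch)) / (2 * h)


def client_update(w: float, batches: List[Batch], lr: float) -> Tuple[float, int]:
    # The training loop, not forward_pass, updates the parameter.
    for batch in batches:
        w = w - lr * loss_grad(w, batch)
    return w, sum(len(xs) for xs, _ in batches)


def fedavg_round(w: float, clients: List[List[Batch]], lr: float) -> float:
    results = [client_update(w, batches, lr) for batches in clients]
    total = sum(n for _, n in results)
    # The server averages client weights, weighted by each client's example count.
    return sum(w_c * n for w_c, n in results) / total


# Toy data from y = 3x, split across two hypothetical clients.
clients = [
    [([1.0, 2.0], [3.0, 6.0])],
    [([3.0], [9.0])],
]
w = 0.0
for _ in range(20):
    w = fedavg_round(w, clients, lr=0.05)
print(round(w, 3))
```

For federated gradient descent, a client would instead return loss_grad without stepping, and the server would average gradients rather than weights.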
Args | |
---|---|
batch_input | A nested structure that matches the structure of Model.input_spec, where each tensor in batch_input satisfies tf.TensorSpec.is_compatible_with() for the corresponding tf.TensorSpec in Model.input_spec.
training | If True, run the training forward pass; otherwise, run in evaluation mode. The semantics are generally the same as the training argument to keras.Model.call; for example, this might influence how dropout or batch normalization is handled.
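The compatibility rule for batch_input can be imitated with plain shape tuples. This hypothetical sketch checks structure and shapes only (real tf.TensorSpec.is_compatible_with also checks dtype): nested dicts must have the same keys, ranks must match, and a None dimension in the spec, typically the batch dimension, accepts any size.

```python
from typing import Any, Optional, Tuple

Shape = Tuple[Optional[int], ...]


def shape_compatible(spec: Shape, actual: Shape) -> bool:
    """True if ranks match and every spec dim is None or equal to the actual dim."""
    return len(spec) == len(actual) and all(
        s is None or s == a for s, a in zip(spec, actual)
    )


def matches_spec(spec: Any, batch: Any) -> bool:
    # Nested dicts must have identical keys; leaves are shape tuples.
    if isinstance(spec, dict):
        return (
            isinstance(batch, dict)
            and spec.keys() == batch.keys()
            and all(matches_spec(spec[k], batch[k]) for k in spec)
        )
    return not isinstance(batch, dict) and shape_compatible(spec, batch)


input_spec = {"x": (None, 784), "y": (None, 1)}  # batch dimension left undefined
print(matches_spec(input_spec, {"x": (32, 784), "y": (32, 1)}))
print(matches_spec(input_spec, {"x": (32, 28, 28), "y": (32, 1)}))
```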
Returns | |
---|---|
A BatchOutput object. The object must include the loss tensor if the model will be trained via a gradient-based algorithm. |
report_local_outputs
report_local_outputs()
Returns tensors representing values aggregated over forward_pass calls.
In federated learning, the values returned by this method will typically be further aggregated across clients and made available on the server.
This method returns results from aggregating across all previous calls to forward_pass, most typically metrics like accuracy and loss. If needed, we may add a clear_aggregated_outputs method, which would likely just run the initializers on the local_variables.
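What such a reset would amount to can be sketched with a hypothetical container of metric state: clearing simply restores every local variable to its initial value. `ToyLocalState`, `record_batch`, and `clear_aggregated_outputs` here are illustrative names only; as the text says, clear_aggregated_outputs is not an existing method.

```python
class ToyLocalState:
    """Hypothetical metric state standing in for a model's local_variables."""

    INITIAL = {"loss_sum": 0.0, "num_examples": 0}

    def __init__(self):
        self.local_variables = dict(self.INITIAL)

    def record_batch(self, batch_loss: float, batch_size: int) -> None:
        # What a forward_pass call would do to the metric accumulators.
        self.local_variables["loss_sum"] += batch_loss * batch_size
        self.local_variables["num_examples"] += batch_size

    def report_local_outputs(self) -> dict:
        return dict(self.local_variables)

    def clear_aggregated_outputs(self) -> None:
        # "Re-run the initializers": reset every local variable to its initial value.
        self.local_variables = dict(self.INITIAL)


state = ToyLocalState()
state.record_batch(0.5, 10)
print(state.report_local_outputs())  # {'loss_sum': 5.0, 'num_examples': 10}
state.clear_aggregated_outputs()
print(state.report_local_outputs())  # {'loss_sum': 0.0, 'num_examples': 0}
```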
In general, the tensors returned can be an arbitrary function of all the tf.Variables of this model, not just the local_variables; for example, this could return tensors measuring the total L2 norm of the model (which might have been updated by training).
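As a small worked example of an output that depends on model variables rather than metric accumulators, the sketch below computes the global L2 norm over all entries of some hypothetical, already-flattened variables and reports it next to a loss value. In TensorFlow, tf.linalg.global_norm computes the same quantity over a list of tensors.

```python
import math
from typing import Dict, List


def global_l2_norm(variables: Dict[str, List[float]]) -> float:
    """L2 norm over all entries of all (flattened) variables."""
    return math.sqrt(sum(v * v for values in variables.values() for v in values))


# Hypothetical model variables after some training, flattened to lists.
model_variables = {"kernel": [3.0, 0.0], "bias": [4.0]}
local_outputs = {
    "loss": 0.25,                                      # from metric accumulators
    "model_l2_norm": global_l2_norm(model_variables),  # from model variables
}
print(local_outputs)  # {'loss': 0.25, 'model_l2_norm': 5.0}
```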
This method may return arbitrarily shaped tensors, not just scalar metrics. For example, it could return the average feature vector or a count of how many times each feature exceeds a certain magnitude.
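Both example outputs are vectors with one entry per feature. The hypothetical `feature_stats` helper below computes them from a single batch for brevity; a real model would accumulate the sums and counts across forward_pass calls before reporting.

```python
from typing import Dict, List


def feature_stats(batch: List[List[float]], threshold: float) -> Dict[str, list]:
    num_features = len(batch[0])
    # Mean of each feature over the rows of the batch.
    mean = [sum(row[j] for row in batch) / len(batch) for j in range(num_features)]
    # Per-feature count of rows whose magnitude exceeds the threshold.
    exceed = [sum(1 for row in batch if abs(row[j]) > threshold) for j in range(num_features)]
    return {"mean_feature_vector": mean, "exceed_counts": exceed}


batch = [[1.0, -5.0, 0.0], [3.0, 2.0, 0.5]]
print(feature_stats(batch, threshold=1.5))
# {'mean_feature_vector': [2.0, -1.5, 0.25], 'exceed_counts': [1, 2, 0]}
```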
Returns | |
---|---|
A structure of tensors (as supported by tf.nest) to be aggregated across clients. |
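Because the return value is a nested structure, aggregation has to pair up matching leaves across clients. The hypothetical `nest_sum` below does this for dicts and lists of numbers in plain Python; in TensorFlow, tf.nest.map_structure applies a function leaf-wise across identically structured inputs, while in TFF the cross-client step itself is expressed with federated operators such as tff.federated_sum.

```python
from typing import Any, List


def nest_sum(structures: List[Any]) -> Any:
    """Sum matching leaves of identically structured nested dicts/lists across clients."""
    first = structures[0]
    if isinstance(first, dict):
        return {k: nest_sum([s[k] for s in structures]) for k in first}
    if isinstance(first, list):
        return [nest_sum([s[i] for s in structures]) for i in range(len(first))]
    return sum(structures)  # leaves are plain numbers


client_outputs = [
    {"num_examples": 10, "exceed_counts": [1, 2, 0]},
    {"num_examples": 5, "exceed_counts": [0, 1, 4]},
]
print(nest_sum(client_outputs))  # {'num_examples': 15, 'exceed_counts': [1, 3, 4]}
```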