Represents a reconstruction model for use in TensorFlow Federated.
tff.learning.models.ReconstructionModels are used to train models that reconstruct a set of their variables on device, never sharing those variables with the server.
Each tff.learning.models.ReconstructionModel will work on a set of tf.Variables, and each method should be a computation that can be implemented as a tf.function. This implies that the class should be essentially stateless from a Python perspective, since each method will generally be traced only once (per set of arguments) to create the corresponding TensorFlow graph functions. Thus, tff.learning.models.ReconstructionModel instances should behave as expected in both eager and graph (TF 1.0) usage.
In general, tf.Variables may be either:
- Weights, the variables needed to make predictions with the model.
- Local variables, e.g. to accumulate aggregated metrics across calls to forward_pass.
The weights can be broken down into:
- Global variables: Variables that are allowed to be aggregated on the server.
- Local variables: Variables that cannot leave the device.
Furthermore, both of these types of variables can be:
- Trainable variables: These can and should be trained using gradient-based methods.
- Non-trainable variables: Could include fixed pre-trained layers or static model data.
These variables are provided via the following properties:
- global_trainable_variables
- global_non_trainable_variables
- local_trainable_variables
- local_non_trainable_variables
These variables must be initialized by the user of the tff.learning.models.ReconstructionModel.
While training a reconstruction model, global trainable variables will generally be provided by the server. Local trainable variables will then be reconstructed locally. Updates to the global trainable variables will be sent back to the server. Local variables are not transmitted.
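The flow described above can be illustrated with a toy, pure-Python simulation. It assumes a hypothetical one-feature model y_hat = g * x + l, where g plays the role of a global trainable variable and l a local trainable variable, both trained with plain gradient descent on mean squared error. None of the names below are TFF APIs, and this is not how TFF implements the process; it only shows which variable is trained in which stage and what leaves the client.

```python
"""Toy simulation of reconstruction training (pure Python, not TFF).

Hypothetical model: y_hat = g * x + l
  g: global trainable variable, sent by the server and aggregated.
  l: local trainable variable, reconstructed on the client, never sent.
"""
from typing import List, Sequence, Tuple

Example = Tuple[float, float]  # (x, y)


def mse_grads(g: float, l: float, data: Sequence[Example]) -> Tuple[float, float]:
    """Gradients of the mean squared error with respect to g and l."""
    n = len(data)
    dg = sum(2.0 * (g * x + l - y) * x for x, y in data) / n
    dl = sum(2.0 * (g * x + l - y) for x, y in data) / n
    return dg, dl


def client_update(global_g: float,
                  recon_data: Sequence[Example],
                  post_recon_data: Sequence[Example],
                  lr: float = 0.1,
                  steps: int = 50) -> float:
    """Returns the update to g; the reconstructed l stays on the client."""
    # Reconstruction: re-initialize the local variable, then train it
    # while the global variable is held fixed.
    l = 0.0
    for _ in range(steps):
        _, dl = mse_grads(global_g, l, recon_data)
        l -= lr * dl
    # Post-reconstruction: hold the reconstructed local variable fixed
    # and train the global variable.
    g = global_g
    for _ in range(steps):
        dg, _ = mse_grads(g, l, post_recon_data)
        g -= lr * dg
    return g - global_g  # only this delta is sent to the server


def server_round(global_g: float, client_datasets: List[Sequence[Example]]) -> float:
    """Averages client updates to g (unweighted, for simplicity)."""
    # Each client reuses its whole dataset for both stages here.
    deltas = [client_update(global_g, d, d) for d in client_datasets]
    return global_g + sum(deltas) / len(deltas)


if __name__ == "__main__":
    xs = [0.0, 1.0, 2.0, 3.0]
    clients = [[(x, 2 * x + 1) for x in xs], [(x, 2 * x - 3) for x in xs]]
    g = 0.0
    for _ in range(30):
        g = server_round(g, clients)
    print(f"global slope after 30 rounds: {g:.4f}")
```

The two toy clients share the slope 2 but have different biases, so the shared g moves toward 2 over rounds while each client's bias is re-fit from scratch every round and never reaches the server.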
All tf.Variables should be introduced in __init__; this could move to a build method more in line with Keras (see https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) in the future.
Attributes | |
---|---|
global_non_trainable_variables | An iterable of tf.Variable objects; see the class description for details. |
global_trainable_variables | An iterable of tf.Variable objects; see the class description for details. |
input_spec | The type specification of the batch_input parameter for forward_pass. A nested structure of tf.TensorSpec objects. |
local_non_trainable_variables | An iterable of tf.Variable objects; see the class description for details. |
local_trainable_variables | An iterable of tf.Variable objects; see the class description for details. |
Methods
build_dataset_split_fn
@classmethod
build_dataset_split_fn(recon_epochs: int = 1, recon_steps_max: Optional[int] = None, post_recon_epochs: int = 1, post_recon_steps_max: Optional[int] = None, split_dataset: bool = False) -> tff.learning.models.ReconstructionDatasetSplitFn
Builds a ReconstructionDatasetSplitFn for training/evaluation.
The returned ReconstructionDatasetSplitFn parameterizes training and evaluation computations and enables reconstruction for multiple local epochs, multiple epochs of post-reconstruction training, limiting the number of steps for both stages, and splitting client datasets into disjoint halves for each stage.
Note that the returned function is used during both training and evaluation: during training, "post-reconstruction" refers to training of global variables, and during evaluation, it refers to calculation of metrics using reconstructed local variables and fixed global variables.
Args | |
---|---|
recon_epochs | The integer number of iterations over the dataset to make during reconstruction. |
recon_steps_max | If not None, the integer maximum number of steps (batches) to iterate through during reconstruction. This maximum is applied across all reconstruction iterations, i.e. after recon_epochs. If None, this has no effect. |
post_recon_epochs | The integer constant number of iterations to make over client data after reconstruction. |
post_recon_steps_max | If not None, the integer maximum number of steps (batches) to iterate through after reconstruction. This maximum is applied across all post-reconstruction iterations, i.e. after post_recon_epochs. If None, this has no effect. |
split_dataset | If True, splits client_dataset in half for each user, using even-indexed entries for reconstruction and odd-indexed entries for post-reconstruction. If False, client_dataset is used for both reconstruction and post-reconstruction, with the above arguments applied. If True, splitting requires that multiple iterations through the dataset yield the same ordering; for example, if client_dataset.shuffle(reshuffle_each_iteration=True) has been called, the split datasets may overlap. Also note that, if True, the dataset should contain more than one batch for reasonable results, since splitting does not occur within batches. |
Returns | |
---|---|
A ReconstructionDatasetSplitFn. |
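The argument semantics above can be modeled with a small pure-Python simulation in which a list of batches stands in for a client tf.data.Dataset. This is a sketch of the documented behavior, not the TFF source; with real datasets, the repeat-then-cap behavior would correspond roughly to Dataset.repeat followed by Dataset.take, and the even/odd split to filtering an enumerated dataset.

```python
from typing import Callable, List, Optional, Sequence, Tuple

# A split function maps one client's batches to (reconstruction, post-reconstruction).
SplitFn = Callable[[Sequence], Tuple[List, List]]


def build_dataset_split_fn(recon_epochs: int = 1,
                           recon_steps_max: Optional[int] = None,
                           post_recon_epochs: int = 1,
                           post_recon_steps_max: Optional[int] = None,
                           split_dataset: bool = False) -> SplitFn:
    """Pure-Python model of the documented split behavior (lists, not tf.data)."""

    def dataset_split_fn(client_dataset: Sequence) -> Tuple[List, List]:
        batches = list(client_dataset)
        if split_dataset:
            recon = batches[0::2]  # even-indexed batches -> reconstruction
            post = batches[1::2]   # odd-indexed batches -> post-reconstruction
        else:
            recon, post = batches, batches  # same data for both stages
        # Repeat for the requested number of epochs first...
        recon = recon * recon_epochs
        post = post * post_recon_epochs
        # ...then cap the total number of steps across all epochs.
        if recon_steps_max is not None:
            recon = recon[:recon_steps_max]
        if post_recon_steps_max is not None:
            post = post[:post_recon_steps_max]
        return recon, post

    return dataset_split_fn


split_fn = build_dataset_split_fn(recon_epochs=2, recon_steps_max=4, split_dataset=True)
recon, post = split_fn(["b0", "b1", "b2", "b3", "b4"])
print(recon)  # ['b0', 'b2', 'b4', 'b0']
print(post)   # ['b1', 'b3']
```

Because epochs are applied before the step cap, recon_steps_max=4 with three reconstruction batches per epoch yields one full pass plus the first batch of the second pass.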
forward_pass
@abc.abstractmethod
forward_pass(batch_input, training=True)
Runs the forward pass and returns results.
This method should not modify any variables that are part of the model parameters, that is, variables that influence the predictions; updating them is the job of the training loop.
Args | |
---|---|
batch_input | A nested structure that matches the structure of the model's input_spec, where each tensor in batch_input satisfies tf.TensorSpec.is_compatible_with() for the corresponding tf.TensorSpec in input_spec. |
training | If True, run the training forward pass; otherwise, run in evaluation mode. The semantics are generally the same as the training argument to keras.Model.__call__; this might, for example, influence how dropout or batch normalization is handled. |
Returns | |
---|---|
A ReconstructionBatchOutput object. |
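A minimal pure-Python illustration of the forward_pass contract. It assumes a two-element (features, labels) batch, as required by the input_spec of the Keras-based constructors, and a hypothetical BatchOutput tuple whose predictions, labels, and num_examples fields are chosen for illustration rather than copied from TFF. The point is that the method reads the parameters but never assigns to them.

```python
from typing import List, NamedTuple, Sequence, Tuple


class BatchOutput(NamedTuple):
    """Hypothetical stand-in for ReconstructionBatchOutput (field names assumed)."""
    predictions: List[float]
    labels: List[float]
    num_examples: int


class ToyLinearModel:
    """Pure-Python toy with one 'global' weight and one 'local' bias."""

    def __init__(self, weight: float, bias: float):
        self.weight = weight  # would be a global trainable variable
        self.bias = bias      # would be a local trainable variable

    def forward_pass(self,
                     batch_input: Tuple[Sequence[float], Sequence[float]],
                     training: bool = True) -> BatchOutput:
        # batch_input follows the two-element (features, labels) convention.
        x, y = batch_input
        # `training` would switch dropout/batch-norm behavior in a real model;
        # this toy has neither, so the flag is unused.
        del training
        # Parameters are read, never assigned: updating them is the
        # training loop's responsibility.
        predictions = [self.weight * xi + self.bias for xi in x]
        return BatchOutput(predictions=predictions, labels=list(y), num_examples=len(x))
```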
from_keras_model_and_layers
@classmethod
from_keras_model_and_layers(keras_model: tf.keras.Model, *, global_layers: Iterable[tf.keras.layers.Layer], local_layers: Iterable[tf.keras.layers.Layer], input_spec: Any) -> 'ReconstructionModel'
Builds a tff.learning.models.ReconstructionModel from a tf.keras.Model.
The tff.learning.models.ReconstructionModel returned by this function uses keras_model for its forward pass and autodifferentiation steps. During reconstruction, variables in local_layers are initialized and trained. Post-reconstruction, variables in global_layers are trained and aggregated on the server. All variables must be partitioned between global and local layers, without overlap.
Args | |
---|---|
keras_model | A tf.keras.Model object that is not compiled. |
global_layers | Iterable of global layers to be aggregated across users. All trainable and non-trainable model variables that can be aggregated on the server should be included in these layers. |
local_layers | Iterable of local layers not shared with the server. All trainable and non-trainable model variables that should not be aggregated on the server should be included in these layers. |
input_spec | A structure of tf.TensorSpecs specifying the type of arguments the model expects. Note that this must be a compound structure of two elements, specifying both the data fed into the model to generate predictions, as its first element, and the expected type of the ground truth, as its second. |
Returns | |
---|---|
A tff.learning.models.ReconstructionModel object. |
Raises | |
---|---|
TypeError | If keras_model is not an instance of tf.keras.Model. |
ValueError | If keras_model was compiled, if input_spec has unexpected structure (e.g., has more than two elements), if global_layers or local_layers contains layers that are not in keras_model, or if global_layers and local_layers are not disjoint in their variables. |
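The layer partitioning rule above can be sketched in pure Python. The sketch assumes a hypothetical dict mapping layer names to the names of the variables each layer owns, standing in for Keras layers and their variables. It rejects unknown layers and variables shared between the global and local groups, mirroring the ValueError conditions listed above, and additionally rejects variables assigned to neither group, following the requirement that all variables be partitioned. It is not TFF's validation code.

```python
from typing import Dict, Iterable, List


def check_layer_partition(model_layers: Dict[str, List[str]],
                          global_layers: Iterable[str],
                          local_layers: Iterable[str]) -> None:
    """Raises ValueError unless global/local layers split the model's variables.

    `model_layers` maps each (hypothetical) layer name to the names of the
    variables that layer owns. A variable listed under two layers models
    weight sharing between layers.
    """
    global_layers, local_layers = list(global_layers), list(local_layers)
    for name in global_layers + local_layers:
        if name not in model_layers:
            raise ValueError(f"Layer {name!r} is not part of the model.")
    global_vars = {v for name in global_layers for v in model_layers[name]}
    local_vars = {v for name in local_layers for v in model_layers[name]}
    shared = global_vars & local_vars
    if shared:
        raise ValueError(f"Variables in both global and local layers: {sorted(shared)}")
    all_vars = {v for owned in model_layers.values() for v in owned}
    unassigned = all_vars - global_vars - local_vars
    if unassigned:
        raise ValueError(f"Variables in neither group: {sorted(unassigned)}")


layers = {
    "user_embedding": ["user_embedding/embeddings"],
    "item_embedding": ["item_embedding/embeddings"],
}
# Example split: per-user embeddings stay on device, item embeddings are aggregated.
check_layer_partition(layers, global_layers=["item_embedding"], local_layers=["user_embedding"])
```

Checking overlap at the variable level rather than the layer level matters because two distinct layers can share a variable.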
from_keras_model_and_variables
@classmethod
from_keras_model_and_variables(keras_model: tf.keras.Model, *, global_trainable_variables: Iterable[tf.Variable], global_non_trainable_variables: Iterable[tf.Variable], local_trainable_variables: Iterable[tf.Variable], local_non_trainable_variables: Iterable[tf.Variable], input_spec: Any) -> 'ReconstructionModel'
Builds a tff.learning.models.ReconstructionModel from a tf.keras.Model.
The tff.learning.models.ReconstructionModel returned by this function uses keras_model for its forward pass and autodifferentiation steps. During reconstruction, variables in local_trainable_variables are initialized and trained, and variables in local_non_trainable_variables are initialized. Post-reconstruction, variables in global_trainable_variables are trained and aggregated on the server. All keras_model variables must be partitioned between global_trainable_variables, global_non_trainable_variables, local_trainable_variables, and local_non_trainable_variables, without overlap.
Args | |
---|---|
keras_model | A tf.keras.Model object that is not compiled. |
global_trainable_variables | The trainable variables to associate with the post-reconstruction phase. |
global_non_trainable_variables | The non-trainable variables to associate with the post-reconstruction phase. |
local_trainable_variables | The trainable variables to associate with the reconstruction phase. |
local_non_trainable_variables | The non-trainable variables to associate with the reconstruction phase. |
input_spec | A structure of tf.TensorSpecs specifying the type of arguments the model expects. Note that this must be a compound structure of two elements, specifying both the data fed into the model to generate predictions, as its first element, and the expected type of the ground truth, as its second. |
Returns | |
---|---|
A tff.learning.models.ReconstructionModel object. |
Raises | |
---|---|
TypeError | If keras_model is not an instance of tf.keras.Model. |
ValueError | If keras_model was compiled, if keras_model is not already built, if input_spec has unexpected structure (e.g., has more than two elements), if any of the variable arguments contains variables that are not in keras_model, or if the variable arguments are not disjoint. |
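The four-way partition requirement above can be checked with a short pure-Python sketch. Variables are represented here by name strings; real tf.Variable objects are not hashable in TF2, so a real check would key on v.ref() instead. The helper name and the example variable names are hypothetical.

```python
import collections
from typing import Iterable, Sequence


def check_variable_partition(model_variables: Sequence[str],
                             *groups: Iterable[str]) -> None:
    """Raises ValueError unless `groups` cover `model_variables` exactly once.

    Variables are represented by name strings; with real tf.Variable objects,
    key on `v.ref()` because the variables themselves are unhashable in TF2.
    """
    counts = collections.Counter(v for group in groups for v in group)
    unknown = sorted(set(counts) - set(model_variables))
    repeated = sorted(v for v, c in counts.items() if c > 1)
    missing = sorted(set(model_variables) - set(counts))
    if unknown or repeated or missing:
        raise ValueError(
            f"Not a partition: unknown={unknown}, repeated={repeated}, missing={missing}")


model_vars = ["user_emb", "dense/kernel", "dense/bias", "bn/moving_mean"]
check_variable_partition(
    model_vars,
    ["dense/kernel", "dense/bias"],  # global_trainable_variables
    ["bn/moving_mean"],              # global_non_trainable_variables
    ["user_emb"],                    # local_trainable_variables
    [],                              # local_non_trainable_variables
)
```

Counting occurrences, rather than only intersecting pairs of groups, catches a variable listed twice in the same group as well as one listed in two different groups.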
get_global_variables
@classmethod
get_global_variables(model: 'ReconstructionModel') -> tff.learning.models.ModelWeights
Gets global variables from model as ModelWeights.
get_local_variables
@classmethod
get_local_variables(model: 'ReconstructionModel') -> tff.learning.models.ModelWeights
Gets local variables from model as ModelWeights.
has_only_global_variables
@classmethod
has_only_global_variables(model: 'ReconstructionModel') -> bool
Returns True if the model has no local variables.
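The three helpers above can be sketched in pure Python against any object that exposes the four variable properties. ModelWeights below is a local stand-in with the same trainable and non_trainable fields as tff.learning.models.ModelWeights, and ToyModel is hypothetical; this illustrates the documented behavior and is not the TFF source.

```python
from typing import Any, List, NamedTuple


class ModelWeights(NamedTuple):
    """Stand-in with the same two fields as tff.learning.models.ModelWeights."""
    trainable: List[Any]
    non_trainable: List[Any]


class ToyModel:
    """Hypothetical object exposing the four variable properties."""

    def __init__(self, gt: List[Any], gn: List[Any], lt: List[Any], ln: List[Any]):
        self.global_trainable_variables = gt
        self.global_non_trainable_variables = gn
        self.local_trainable_variables = lt
        self.local_non_trainable_variables = ln


def get_global_variables(model: ToyModel) -> ModelWeights:
    return ModelWeights(trainable=list(model.global_trainable_variables),
                        non_trainable=list(model.global_non_trainable_variables))


def get_local_variables(model: ToyModel) -> ModelWeights:
    return ModelWeights(trainable=list(model.local_trainable_variables),
                        non_trainable=list(model.local_non_trainable_variables))


def has_only_global_variables(model: ToyModel) -> bool:
    # True when neither local group contains any variable.
    local = get_local_variables(model)
    return not local.trainable and not local.non_trainable
```

A model for which has_only_global_variables is True has nothing to reconstruct, so every variable can be handled by the server.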
read_metric_variables
@classmethod
read_metric_variables(metrics: list[tf.keras.metrics.Metric]) -> collections.OrderedDict[str, list[tf.Tensor]]
Reads values from Keras metric variables.
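A minimal sketch of what reading metric variables involves, using hypothetical stand-in classes that model only the attributes the sketch relies on: a metric's name and variables, and a variable's read_value(), all of which exist on tf.keras.metrics.Metric and tf.Variable. Keying the result by metric name and rejecting duplicate names are assumptions of this sketch.

```python
import collections
from typing import Any, List, Sequence


class FakeVariable:
    """Stand-in for tf.Variable: only read_value() is modeled."""

    def __init__(self, value: Any):
        self._value = value

    def read_value(self) -> Any:
        return self._value


class FakeMetric:
    """Stand-in for tf.keras.metrics.Metric: a name plus its state variables."""

    def __init__(self, name: str, variables: List[FakeVariable]):
        self.name = name
        self.variables = variables


def read_metric_variables(metrics: Sequence[FakeMetric]) -> "collections.OrderedDict[str, List[Any]]":
    """Maps each metric's name to the current values of its variables, in order."""
    values = collections.OrderedDict()
    for metric in metrics:
        if metric.name in values:
            # Sketch choice: refuse to silently overwrite a metric's values.
            raise ValueError(f"Duplicate metric name: {metric.name!r}")
        values[metric.name] = [v.read_value() for v in metric.variables]
    return values


# E.g. a mean-style metric whose state is a running total and a count.
acc = FakeMetric("accuracy", [FakeVariable(8.0), FakeVariable(10.0)])
loss = FakeMetric("loss", [FakeVariable(3.5), FakeVariable(10.0)])
print(read_metric_variables([acc, loss]))
```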
simple_dataset_split_fn
@classmethod
simple_dataset_split_fn(client_dataset: tf.data.Dataset) -> tuple[tf.data.Dataset, tf.data.Dataset]
A ReconstructionDatasetSplitFn that returns the original client data.
Both the reconstruction data and post-reconstruction data will result from iterating over the same tf.data.Dataset. Note that, depending on any preprocessing steps applied to client tf.data.Datasets, this may not produce exactly the same data in the same order for both reconstruction and post-reconstruction. For example, if client_dataset.shuffle(reshuffle_each_iteration=True) was applied, post-reconstruction data will be in a different order than reconstruction data.
Args | |
---|---|
client_dataset | tf.data.Dataset representing client data. |
Returns | |
---|---|
A tuple of two tf.data.Datasets, the first to be used for reconstruction, the second to be used for post-reconstruction. |
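To see why the order can differ, here is a pure-Python sketch in which a hypothetical ReshufflingDataset produces a new permutation on every pass, imitating shuffle(reshuffle_each_iteration=True). The sketched simple_dataset_split_fn just returns the same object twice, as described above; it is an illustration, not the TFF implementation.

```python
import random
from typing import Any, Iterable, Iterator, Tuple


def simple_dataset_split_fn(client_dataset: Iterable[Any]) -> Tuple[Iterable[Any], Iterable[Any]]:
    """Returns the client data unchanged for both stages."""
    return client_dataset, client_dataset


class ReshufflingDataset:
    """Hypothetical iterable that yields its items in a new order on each pass."""

    def __init__(self, items: Iterable[Any], seed: int = 0):
        self._items = list(items)
        self._rng = random.Random(seed)

    def __iter__(self) -> Iterator[Any]:
        order = list(self._items)
        self._rng.shuffle(order)  # fresh permutation every iteration
        return iter(order)


recon, post = simple_dataset_split_fn(ReshufflingDataset(range(10)))
recon_order, post_order = list(recon), list(post)
# Same elements, but each stage iterated the dataset separately, so the
# order is generally different.
print(sorted(recon_order) == sorted(post_order))  # True
```

With a dataset whose iteration order is fixed, both stages see identical data in identical order.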