Creates a baseline task for character recognition on EMNIST.
tff.simulation.baselines.emnist.create_character_recognition_task(
    train_client_spec: tff.simulation.baselines.ClientSpec,
    eval_client_spec: Optional[tff.simulation.baselines.ClientSpec] = None,
    model_id: Union[str, tff.simulation.baselines.emnist.CharacterRecognitionModel] = 'cnn_dropout',
    only_digits: bool = False,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False,
    debug_seed: Optional[int] = None
) -> tff.simulation.baselines.BaselineTask
The goal of the task is to minimize the sparse categorical crossentropy between the model's predicted distribution over labels and the true label of each image. When only_digits = True, there are 10 possible labels (the digits 0-9); when only_digits = False, there are 62 possible labels (the 10 digits plus 26 uppercase and 26 lowercase letters).
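As a minimal sketch of the objective described above, the following pure-Python code computes the sparse categorical crossentropy for a single example. It assumes the model emits one unnormalized logit per class; the helper names num_classes and sparse_categorical_crossentropy are hypothetical and are not part of the TFF API.

```python
import math
from typing import Sequence


def num_classes(only_digits: bool) -> int:
    """Number of output labels for the EMNIST character recognition task."""
    # 10 digits only, or 10 digits + 26 uppercase + 26 lowercase letters.
    return 10 if only_digits else 10 + 26 + 26


def sparse_categorical_crossentropy(logits: Sequence[float], label: int) -> float:
    """Cross-entropy between softmax(logits) and an integer class label.

    "Sparse" means the true label is an integer index, not a one-hot vector.
    Assumption: `logits` are unnormalized scores, one per class.
    """
    # Subtract the largest logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # -log p(label) = log(sum_j exp(z_j)) - z_label
    return log_sum_exp - logits[label]


# A maximally uncertain model (all logits equal) pays log(num_classes).
for only_digits in (True, False):
    k = num_classes(only_digits)
    loss = sparse_categorical_crossentropy([0.0] * k, label=3)
    print(only_digits, k, round(loss, 4))
```

With all logits equal, the loss is log(10) ≈ 2.303 for only_digits = True and log(62) ≈ 4.127 for only_digits = False, which gives a rough reference point for the loss of an untrained model.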
This classification can be done using a number of different models, specified using the model_id argument. The following models are available:

- model_id = cnn_dropout: A moderately sized convolutional network. Uses two convolutional layers, a max pooling layer, and dropout, followed by two dense layers.
- model_id = cnn: A moderately sized convolutional network without any dropout layers. Matches the architecture of the convolutional network used by McMahan et al. (2017) for testing the FedAvg algorithm.
- model_id = 2nn: A densely connected network with 2 hidden layers, each with 200 hidden units and ReLU activations.
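To make the size of the 2nn model concrete, the sketch below counts its trainable parameters. It assumes each 28x28 EMNIST image is flattened to 784 inputs, that every dense layer has a bias, and that the output layer is a single dense layer with one unit per class; the function names are hypothetical, and TFF's Keras implementation may differ in details.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def two_nn_param_count(num_classes: int, input_size: int = 28 * 28, hidden: int = 200) -> int:
    """Parameters of an MLP: input -> 200 (ReLU) -> 200 (ReLU) -> num_classes.

    Assumption: flattened 28x28 inputs and a bias in every dense layer.
    """
    return (
        dense_params(input_size, hidden)     # first hidden layer
        + dense_params(hidden, hidden)       # second hidden layer
        + dense_params(hidden, num_classes)  # output layer (one unit per class)
    )


print(two_nn_param_count(10))  # only_digits = True
print(two_nn_param_count(62))  # only_digits = False
```

Under these assumptions the count is 199,210 for the 10-class task, which matches the total reported for the 2NN in McMahan et al. (2017), and 209,662 for the 62-class task.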
Args | |
---|---|
train_client_spec | A tff.simulation.baselines.ClientSpec specifying how to preprocess train client data. |
eval_client_spec | An optional tff.simulation.baselines.ClientSpec specifying how to preprocess evaluation client data. If set to None, the evaluation datasets will use a batch size of 64 with no extra preprocessing. |
model_id | A string identifier (or tff.simulation.baselines.emnist.CharacterRecognitionModel value) for a character recognition model. Must be one of 'cnn_dropout', 'cnn', or '2nn'. These correspond respectively to a CNN model with dropout, a CNN model with no dropout, and a densely connected network with two hidden layers of width 200. |
only_digits | A boolean indicating whether to use the smaller EMNIST-10 dataset with only 10 numeric classes (True) or the full EMNIST-62 dataset containing 62 alphanumeric classes (False). |
cache_dir | An optional directory in which to cache the downloaded datasets. If None, they will be cached to ~/.tff/. |
use_synthetic_data | A boolean indicating whether to use synthetic EMNIST data. This option should only be used for testing purposes, to avoid downloading the entire EMNIST dataset. |
debug_seed | An optional integer seed to force deterministic model initialization. This is intended for unit testing. |
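
The client specs control how each client's examples are turned into batches. The following pure-Python sketch illustrates that idea with a hypothetical preprocess_client_data function. Its parameters (num_epochs, batch_size, max_batches, shuffle_seed) are illustrative stand-ins rather than the actual fields of tff.simulation.baselines.ClientSpec, and TFF's real preprocessing operates on tf.data datasets. The defaults mimic the evaluation behavior described for eval_client_spec = None, interpreting "no extra preprocessing" as a single unshuffled pass.

```python
import random
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def preprocess_client_data(
    examples: Sequence[T],
    num_epochs: int = 1,
    batch_size: int = 64,
    max_batches: Optional[int] = None,
    shuffle_seed: Optional[int] = None,
) -> Iterator[List[T]]:
    """Hypothetical stand-in for per-client preprocessing (not TFF's code).

    Repeats the client's examples for num_epochs, shuffles each epoch if a
    seed is given, groups each epoch into batches of batch_size (the last
    batch of an epoch may be smaller), and yields at most max_batches batches.
    """
    rng = random.Random(shuffle_seed)
    emitted = 0
    for _ in range(num_epochs):
        epoch = list(examples)
        if shuffle_seed is not None:
            rng.shuffle(epoch)
        for start in range(0, len(epoch), batch_size):
            if max_batches is not None and emitted >= max_batches:
                return
            yield epoch[start:start + batch_size]
            emitted += 1


client_examples = list(range(150))  # stand-in for one client's 150 examples

# Evaluation-style default: batch size 64, one pass, no shuffling.
eval_batches = list(preprocess_client_data(client_examples))
print([len(b) for b in eval_batches])

# Training-style settings: two shuffled epochs, capped at 10 batches.
train_batches = list(
    preprocess_client_data(
        client_examples, num_epochs=2, batch_size=20, max_batches=10, shuffle_seed=0
    )
)
print(len(train_batches))
```

For a 150-example client, the evaluation-style call yields batches of 64, 64, and 22 examples, while the training-style call stops after 10 of the 16 possible batches.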
Returns | |
---|---|
A tff.simulation.baselines.BaselineTask. |