Creates a baseline task for image classification on the Google Landmarks Dataset v2 (GLDv2).
tff.simulation.baselines.landmark.create_landmark_classification_task(
    train_client_spec: tff.simulation.baselines.ClientSpec,
    eval_client_spec: Optional[tff.simulation.baselines.ClientSpec] = None,
    use_gld23k: bool = False,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False,
    debug_seed: Optional[int] = None
) -> tff.simulation.baselines.BaselineTask
The goal of the task is to minimize the sparse categorical cross-entropy
between the model's predictions and the true label of each image. The model
is a MobileNetV2 that expects input images of shape [128, 128, 3] and uses
group normalization layers with 8 groups.
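To make the objective and the normalization concrete, here is a minimal pure-Python sketch of the two ideas the paragraph names: sparse categorical cross-entropy computed from raw logits and integer labels, and group normalization with 8 groups. This is not the TFF or Keras implementation; the function names are hypothetical, the group-norm sketch omits the learned per-channel scale and offset, and it is meant for an intermediate feature map rather than the raw [128, 128, 3] input (3 channels cannot be split into 8 groups).

```python
import math


def sparse_categorical_cross_entropy(logits, labels):
    """Mean of -log softmax(logits)[label] over a batch.

    logits: list of per-example lists of raw class scores.
    labels: list of integer class ids (the "sparse" part: no one-hot vectors).
    """
    total = 0.0
    for row, label in zip(logits, labels):
        # Log-sum-exp with the max subtracted, for numerical stability.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def group_norm(feature_map, num_groups=8, eps=1e-5):
    """Normalizes an [H, W, C] nested list using num_groups channel groups.

    Learned per-channel scale and offset are omitted in this sketch.
    """
    height, width, channels = len(feature_map), len(feature_map[0]), len(feature_map[0][0])
    if channels % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups")
    group_size = channels // num_groups
    out = [[[0.0] * channels for _ in range(width)] for _ in range(height)]
    for g in range(num_groups):
        chans = range(g * group_size, (g + 1) * group_size)
        # One mean/variance per group, shared over all spatial positions
        # and over every channel in the group.
        values = [feature_map[i][j][c]
                  for i in range(height) for j in range(width) for c in chans]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        scale = 1.0 / math.sqrt(var + eps)
        for i in range(height):
            for j in range(width):
                for c in chans:
                    out[i][j][c] = (feature_map[i][j][c] - mean) * scale
    return out


# Uniform logits over 4 classes give a loss of log(4).
print(round(sparse_categorical_cross_entropy([[0.0, 0.0, 0.0, 0.0]], [2]), 3))  # 1.386
```

Subtracting the row maximum before exponentiating leaves the loss unchanged but keeps math.exp from overflowing on large logits.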
Args
train_client_spec | A tff.simulation.baselines.ClientSpec specifying how to preprocess train client data.
eval_client_spec | An optional tff.simulation.baselines.ClientSpec specifying how to preprocess evaluation client data. If set to None, the evaluation datasets will use a batch_size of 64.
use_gld23k | An optional boolean. When True, a smaller version of the GLDv2 landmark dataset (gld23k) is loaded; it is intended for faster prototyping.
cache_dir | An optional directory in which to cache the downloaded datasets. If not specified, they are cached in the default cache directory (cache).
use_synthetic_data | An optional boolean indicating whether to use synthetic GLDv2 data. Use this option only for testing, to avoid downloading the entire GLDv2 dataset.
debug_seed | An optional integer seed that forces deterministic model initialization. It is intended for unit testing.
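The None default for eval_client_spec can be pictured with the minimal sketch below. EvalSpecSketch, resolve_eval_spec, and batch_sizes_one_pass are hypothetical names, not part of TFF. Only the batch size of 64 comes from the eval_client_spec description; num_epochs=1 is an assumption, and the batching helper assumes the final partial batch is kept, as tf.data.Dataset.batch does by default (drop_remainder=False).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class EvalSpecSketch:
    """Hypothetical stand-in for tff.simulation.baselines.ClientSpec.

    Only the two fields needed here are modeled.
    """
    num_epochs: int
    batch_size: int


# Documented default: evaluation data uses batch_size 64 when
# eval_client_spec is None.
DEFAULT_EVAL_BATCH_SIZE = 64


def resolve_eval_spec(eval_client_spec: Optional[EvalSpecSketch]) -> EvalSpecSketch:
    """Returns the caller's spec unchanged, or a default spec with batch_size 64."""
    if eval_client_spec is None:
        # num_epochs=1 is an assumption of this sketch; only batch_size is documented.
        return EvalSpecSketch(num_epochs=1, batch_size=DEFAULT_EVAL_BATCH_SIZE)
    return eval_client_spec


def batch_sizes_one_pass(num_examples: int, batch_size: int) -> List[int]:
    """Batch sizes for one pass over a client's data, keeping a final partial batch."""
    full, rest = divmod(num_examples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


spec = resolve_eval_spec(None)
print(spec.batch_size)                              # 64
print(batch_sizes_one_pass(130, spec.batch_size))   # [64, 64, 2]
```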