Creates a baseline task of image classification on GLDv2.
tff.simulation.baselines.landmark.create_landmark_classification_task(
    train_client_spec: tff.simulation.baselines.ClientSpec,
    eval_client_spec: Optional[tff.simulation.baselines.ClientSpec] = None,
    use_gld23k: bool = False,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False,
    debug_seed: Optional[int] = None
) -> tff.simulation.baselines.BaselineTask
The goal of the task is to minimize the sparse categorical cross-entropy between the model's predictions and the true label of each image. The task creates a MobileNetV2 model that expects input images of shape [128, 128, 3] and uses group normalization layers with 8 groups.
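To make the objective and the normalization concrete, here is a minimal pure-Python sketch (not TFF's implementation) of sparse categorical cross-entropy computed from logits and an integer label, and of group normalization with 8 groups over a [H][W][C] feature map. The function names and the `eps` value are assumptions, and learned scale/offset parameters are omitted. Group normalization acts on intermediate feature maps, so the sketch requires a channel count divisible by the number of groups rather than the 3-channel input image.

```python
import math
from typing import List, Sequence

def sparse_categorical_crossentropy(logits: Sequence[float], label: int) -> float:
    """Cross-entropy between softmax(logits) and an integer class label."""
    # Numerically stable log-sum-exp: subtract the max logit before exp().
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # -log(softmax(logits)[label]) == log_sum_exp - logits[label]
    return log_sum_exp - logits[label]

def group_norm(x: List[List[List[float]]], num_groups: int = 8,
               eps: float = 1e-5) -> List[List[List[float]]]:
    """Group normalization of a feature map x with shape [H][W][C].

    Channels are split into `num_groups` contiguous groups; each group is
    normalized with its own mean and variance over all H*W positions and
    all channels in that group. `eps` is an assumed default.
    """
    channels = len(x[0][0])
    if channels % num_groups != 0:
        raise ValueError("channel count must be divisible by num_groups")
    size = channels // num_groups
    out = [[[0.0] * channels for _ in row] for row in x]
    for g in range(num_groups):
        lo, hi = g * size, (g + 1) * size
        # Gather every value belonging to this group.
        values = [px[c] for row in x for px in row for c in range(lo, hi)]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        inv_std = 1.0 / math.sqrt(var + eps)
        for i, row in enumerate(x):
            for j, px in enumerate(row):
                for c in range(lo, hi):
                    out[i][j][c] = (px[c] - mean) * inv_std
    return out
```

For example, uniform logits over 4 classes give a loss of log 4 (about 1.386) regardless of the label, and each of the 8 groups in the output of `group_norm` has mean close to 0 and variance close to 1.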
Args | |
---|---|
`train_client_spec` | A `tff.simulation.baselines.ClientSpec` specifying how to preprocess training client data. |
`eval_client_spec` | An optional `tff.simulation.baselines.ClientSpec` specifying how to preprocess evaluation client data. If set to `None`, the evaluation datasets use a `batch_size` of 64. |
`use_gld23k` | An optional boolean. If `True`, loads gld23k, a smaller version of the GLDv2 landmark dataset intended for faster prototyping. |
`cache_dir` | An optional directory in which to cache the downloaded datasets. If not specified, they are cached in the default cache directory, `cache`. |
`use_synthetic_data` | An optional boolean indicating whether to use synthetic GLDv2 data. Use this option only for testing, to avoid downloading the entire GLDv2 dataset. |
`debug_seed` | An optional integer seed that forces deterministic model initialization. Intended for unit testing. |
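The sketch below is a minimal, hypothetical illustration of how a client spec might drive per-client preprocessing, and of how the documented `eval_client_spec=None` default could be resolved. `ClientSpecSketch`, `default_eval_spec`, and `preprocess` are illustrative names, not TFF APIs. The fields mirror those we assume `ClientSpec` exposes (`num_epochs`, `batch_size`, `max_elements`, `shuffle_buffer_size`); the order of operations and the `num_epochs=1` evaluation default are assumptions. Only the evaluation `batch_size` of 64 comes from the argument table.

```python
import dataclasses
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")

@dataclasses.dataclass(frozen=True)
class ClientSpecSketch:
    """Hypothetical stand-in for tff.simulation.baselines.ClientSpec."""
    num_epochs: int
    batch_size: int
    max_elements: Optional[int] = None
    shuffle_buffer_size: Optional[int] = None

def default_eval_spec(eval_client_spec: Optional[ClientSpecSketch]) -> ClientSpecSketch:
    # Documented behavior: a None eval spec means a batch_size of 64.
    # num_epochs=1 is an assumption made for this sketch.
    if eval_client_spec is None:
        return ClientSpecSketch(num_epochs=1, batch_size=64)
    return eval_client_spec

def preprocess(examples: Sequence[T], spec: ClientSpecSketch,
               seed: int = 0) -> List[List[T]]:
    """Repeat, optionally shuffle, truncate, and batch one client's examples."""
    rng = random.Random(seed)
    stream: List[T] = []
    for _ in range(spec.num_epochs):
        epoch = list(examples)
        if spec.shuffle_buffer_size and spec.shuffle_buffer_size > 1:
            # Full in-memory shuffle per epoch; tf.data uses a bounded buffer.
            rng.shuffle(epoch)
        stream.extend(epoch)
    if spec.max_elements is not None:
        stream = stream[: spec.max_elements]  # cap the number of examples
    # The final batch may be smaller than batch_size.
    return [stream[i : i + spec.batch_size]
            for i in range(0, len(stream), spec.batch_size)]
```

For instance, `preprocess(list(range(6)), ClientSpecSketch(num_epochs=2, batch_size=4, max_elements=10))` repeats the six examples twice, keeps the first 10, and yields batches of sizes 4, 4, and 2.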
Returns | |
---|---|
A `tff.simulation.baselines.BaselineTask`. |