tff.simulation.baselines.cifar100.create_image_classification_task
Creates a baseline task for image classification on CIFAR-100.
tff.simulation.baselines.cifar100.create_image_classification_task(
    train_client_spec: tff.simulation.baselines.ClientSpec,
    eval_client_spec: Optional[tff.simulation.baselines.ClientSpec] = None,
    model_id: Union[str, tff.simulation.baselines.cifar100.ResnetModel] = 'resnet18',
    crop_height: int = DEFAULT_CROP_HEIGHT,
    crop_width: int = DEFAULT_CROP_WIDTH,
    distort_train_images: bool = False,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False
) -> tff.simulation.baselines.BaselineTask
The goal of the task is to minimize the sparse categorical crossentropy
between the model's predictions and the true label of each image.
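For a single image, sparse categorical crossentropy is the negative log of the probability the model assigns to the true integer class label; during training it is typically averaged over a batch. The minimal pure-Python sketch below illustrates only that formula. It assumes raw scores are first turned into probabilities with a softmax, and `softmax`, `sparse_categorical_crossentropy`, and `mean_loss` are hypothetical helper names, not the Keras or TFF implementation.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Converts raw scores into probabilities (shifted by the max for stability)."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(probs: Sequence[float], label: int) -> float:
    """Loss for one example: -log of the probability assigned to the integer label."""
    return -math.log(probs[label])


def mean_loss(batch_probs: Sequence[Sequence[float]],
              batch_labels: Sequence[int]) -> float:
    """Averages the per-example losses over a batch."""
    losses = [sparse_categorical_crossentropy(p, y)
              for p, y in zip(batch_probs, batch_labels)]
    return sum(losses) / len(losses)


# Toy batch with 4 classes (CIFAR-100 labels range over 100 classes).
logits = [[2.0, 0.5, 0.1, -1.0], [0.0, 0.0, 0.0, 0.0]]
labels = [0, 3]
probs = [softmax(row) for row in logits]
print(mean_loss(probs, labels))
```

Labels are plain integer class indices rather than one-hot vectors, which is what distinguishes the "sparse" variant of categorical crossentropy.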
<table>
<tr><th colspan="2">Args</th></tr>
<tr>
<td> train_client_spec<a id="train_client_spec"></a>
</td>
<td>
A <code>tff.simulation.baselines.ClientSpec</code> specifying how to
preprocess train client data.
</td>
</tr><tr>
<td> eval_client_spec<a id="eval_client_spec"></a>
</td>
<td>
An optional <code>tff.simulation.baselines.ClientSpec</code> specifying how to
preprocess evaluation client data. If set to <code>None</code>, the evaluation
datasets will use a batch size of 64 with no extra preprocessing.
</td>
</tr><tr>
<td> model_id<a id="model_id"></a>
</td>
<td>
A string identifier (or <code>tff.simulation.baselines.cifar100.ResnetModel</code>
value) for an image classification model. Must be one of <code>resnet18</code>,
<code>resnet34</code>, <code>resnet50</code>, <code>resnet101</code>, or
<code>resnet152</code>. These correspond to the ResNet architectures of the same
names, except that the batch normalization layers are replaced with group
normalization.
</td>
</tr><tr>
<td> crop_height<a id="crop_height"></a>
</td>
<td>
An integer specifying the desired height for cropping images.
Must be between 1 and 32 (the height of uncropped CIFAR-100 images). By
default, this is set to
<a href="../../../../tff/simulation/baselines/cifar100#DEFAULT_CROP_HEIGHT"><code>tff.simulation.baselines.cifar100.DEFAULT_CROP_HEIGHT</code></a>.
</td>
</tr><tr>
<td> crop_width<a id="crop_width"></a>
</td>
<td>
An integer specifying the desired width for cropping images.
Must be between 1 and 32 (the width of uncropped CIFAR-100 images). By
default, this is set to
<a href="../../../../tff/simulation/baselines/cifar100#DEFAULT_CROP_WIDTH"><code>tff.simulation.baselines.cifar100.DEFAULT_CROP_WIDTH</code></a>.
</td>
</tr><tr>
<td> distort_train_images<a id="distort_train_images"></a>
</td>
<td>
Whether to distort images in the train preprocessing function.
</td>
</tr><tr>
<td> cache_dir<a id="cache_dir"></a>
</td>
<td>
An optional directory in which to cache the downloaded datasets. If
<code>None</code>, they will be cached to <code>~/.tff/</code>.
</td>
</tr><tr>
<td> use_synthetic_data<a id="use_synthetic_data"></a>
</td>
<td>
A boolean indicating whether to use synthetic CIFAR-100 data. This option
should only be used for testing purposes, to avoid downloading the entire
CIFAR-100 dataset.
</td>
</tr>
</table>
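The argument rules in the table above can be summarized as a small validation routine. The pure-Python sketch below is hypothetical and not part of TFF: `EvalSpec`, `validate_task_args`, and `SUPPORTED_MODEL_IDS` are illustrative stand-ins, the sketch accepts only the string form of `model_id`, and it does not reproduce the actual values of `DEFAULT_CROP_HEIGHT` or `DEFAULT_CROP_WIDTH`.

```python
import os
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical stand-ins for illustration; NOT TFF classes or constants.
SUPPORTED_MODEL_IDS = ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")
CIFAR100_IMAGE_SIZE = 32  # Uncropped CIFAR-100 images are 32x32.


@dataclass
class EvalSpec:
    """Hypothetical stand-in for the batching part of a ClientSpec."""
    batch_size: int


def validate_task_args(model_id: str,
                       crop_height: int,
                       crop_width: int,
                       eval_spec: Optional[EvalSpec] = None,
                       cache_dir: Optional[str] = None) -> Tuple[EvalSpec, str]:
    """Checks the documented constraints and fills in the documented defaults."""
    if model_id not in SUPPORTED_MODEL_IDS:
        raise ValueError(
            f"model_id must be one of {SUPPORTED_MODEL_IDS}, got {model_id!r}")
    for name, value in (("crop_height", crop_height), ("crop_width", crop_width)):
        if not 1 <= value <= CIFAR100_IMAGE_SIZE:
            raise ValueError(
                f"{name} must be between 1 and {CIFAR100_IMAGE_SIZE}, got {value}")
    if eval_spec is None:
        # Documented default: evaluation batch size of 64, no extra preprocessing.
        eval_spec = EvalSpec(batch_size=64)
    if cache_dir is None:
        # Documented default cache location, expanded to an absolute path.
        cache_dir = os.path.expanduser(os.path.join("~", ".tff"))
    return eval_spec, cache_dir


eval_spec, cache_dir = validate_task_args("resnet18", crop_height=24, crop_width=24)
print(eval_spec.batch_size, cache_dir)
```

Raising `ValueError` early keeps an out-of-range crop size or an unknown model identifier from surfacing later as a harder-to-read shape or lookup error.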
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-09-20 UTC.