Creates a baseline task for tag prediction on Stack Overflow.
tff.simulation.baselines.stackoverflow.create_tag_prediction_task(
    train_client_spec: tff.simulation.baselines.ClientSpec,
    eval_client_spec: Optional[tff.simulation.baselines.ClientSpec] = None,
    word_vocab_size: int = constants.DEFAULT_WORD_VOCAB_SIZE,
    tag_vocab_size: int = constants.DEFAULT_TAG_VOCAB_SIZE,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False
) -> tff.simulation.baselines.BaselineTask
The goal of the task is to predict the tags associated with a post based on a
bag-of-words representation of the post.
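To make the input and output of this task concrete, here is a minimal pure-Python sketch of how a post could be turned into a bag-of-words feature vector and its tags into a multi-hot label vector. The toy vocabularies, the lowercase whitespace tokenization, the use of raw counts, and the helper names `build_index`, `bag_of_words`, and `multi_hot_tags` are all hypothetical; TFF's actual preprocessing may tokenize, normalize, or handle out-of-vocabulary words differently.

```python
from typing import Dict, List, Sequence


def build_index(vocab: Sequence[str]) -> Dict[str, int]:
    """Maps each vocabulary entry to a fixed position in the output vector."""
    return {token: i for i, token in enumerate(vocab)}


def bag_of_words(text: str, word_index: Dict[str, int]) -> List[float]:
    """Counts in-vocabulary words in `text`; out-of-vocabulary words are dropped."""
    vector = [0.0] * len(word_index)
    for word in text.lower().split():
        position = word_index.get(word)
        if position is not None:
            vector[position] += 1.0
    return vector


def multi_hot_tags(tags: Sequence[str], tag_index: Dict[str, int]) -> List[float]:
    """Sets 1.0 for every known tag; a single post may carry several tags."""
    vector = [0.0] * len(tag_index)
    for tag in tags:
        position = tag_index.get(tag)
        if position is not None:
            vector[position] = 1.0
    return vector


# Hypothetical toy vocabularies; the real task derives them from the corpus.
word_index = build_index(["how", "to", "sort", "a", "list", "dict"])
tag_index = build_index(["python", "java", "sorting"])

features = bag_of_words("How to sort a list of a list", word_index)
labels = multi_hot_tags(["python", "sorting"], tag_index)
print(features)  # [1.0, 1.0, 1.0, 2.0, 2.0, 0.0]
print(labels)    # [1.0, 0.0, 1.0]
```

A post with several tags yields several 1.0 entries in its label vector, so each tag position can be scored independently of the others.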
Args:
  train_client_spec: A tff.simulation.baselines.ClientSpec specifying how to
    preprocess train client data.
  eval_client_spec: An optional tff.simulation.baselines.ClientSpec
    specifying how to preprocess evaluation client data. If set to None, the
    evaluation datasets will use a batch size of 64 with no extra
    preprocessing.
  word_vocab_size: An integer specifying how many of the most frequent words
    in the entire corpus to use as the task's vocabulary. By default, this is
    set to tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE.
  tag_vocab_size: An integer specifying how many of the most frequent tags in
    the entire corpus to use as the task's labels. By default, this is set to
    tff.simulation.baselines.stackoverflow.DEFAULT_TAG_VOCAB_SIZE.
  cache_dir: An optional directory in which to cache the downloaded datasets.
    If None, they will be cached to ~/.tff/.
  use_synthetic_data: A boolean indicating whether to use synthetic Stack
    Overflow data. This option should only be used for testing purposes, in
    order to avoid downloading the entire Stack Overflow dataset. Synthetic
    word and tag vocabularies will also be used (not necessarily of sizes
    word_vocab_size and tag_vocab_size).
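Both word_vocab_size and tag_vocab_size keep only the most frequent entries of the corpus. The sketch below shows that top-k truncation on a toy, in-memory corpus of tokenized posts; the helper name `top_k_vocab` and the corpus are hypothetical, and the real task may well derive its vocabularies from precomputed corpus-wide counts rather than by counting like this. Ties are broken by first occurrence, which is the documented behavior of `collections.Counter.most_common`.

```python
from collections import Counter
from typing import Iterable, List, Sequence


def top_k_vocab(documents: Iterable[Sequence[str]], k: int) -> List[str]:
    """Returns the k most frequent items across all documents.

    Items with equal counts keep first-encountered order (Counter.most_common).
    """
    counts: Counter = Counter()
    for doc in documents:
        counts.update(doc)
    return [item for item, _ in counts.most_common(k)]


# Hypothetical toy corpus of tokenized post bodies and their tags.
posts = [
    ["how", "to", "sort", "a", "list"],
    ["sort", "a", "dict", "by", "value"],
    ["a", "list", "of", "lists"],
]
post_tags = [["python", "sorting"], ["python", "dictionary"], ["python"]]

word_vocab = top_k_vocab(posts, k=3)     # analogous to word_vocab_size=3
tag_vocab = top_k_vocab(post_tags, k=2)  # analogous to tag_vocab_size=2
print(word_vocab)  # ['a', 'sort', 'list']
print(tag_vocab)   # ['python', 'sorting']
```

Words or tags that fall outside the top k are simply absent from the resulting vocabulary, so they contribute nothing to the features or labels.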