tff.simulation.baselines.stackoverflow.create_word_prediction_task
Creates a baseline task for next-word prediction on Stack Overflow.
tff.simulation.baselines.stackoverflow.create_word_prediction_task(
    train_client_spec: tff.simulation.baselines.ClientSpec,
    eval_client_spec: Optional[tff.simulation.baselines.ClientSpec] = None,
    sequence_length: int = constants.DEFAULT_SEQUENCE_LENGTH,
    vocab_size: int = constants.DEFAULT_WORD_VOCAB_SIZE,
    num_out_of_vocab_buckets: int = 1,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False
) -> tff.simulation.baselines.BaselineTask
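A call matching this signature might look like the sketch below. It assumes the tensorflow_federated package is installed and still exposes tff.simulation.baselines. The helper name build_synthetic_task and the values chosen for num_epochs, batch_size, sequence_length, and vocab_size are illustrative assumptions, not documented defaults.

```python
def build_synthetic_task():
    """Builds the word prediction task on synthetic data (hypothetical helper)."""
    # Imported lazily so that defining this helper does not require the package.
    import tensorflow_federated as tff

    # Illustrative preprocessing settings for the training clients.
    train_client_spec = tff.simulation.baselines.ClientSpec(
        num_epochs=1, batch_size=16)

    return tff.simulation.baselines.stackoverflow.create_word_prediction_task(
        train_client_spec,
        eval_client_spec=None,      # fall back to the default evaluation setup
        sequence_length=20,         # illustrative value
        vocab_size=10000,           # illustrative value
        num_out_of_vocab_buckets=1,
        cache_dir=None,
        use_synthetic_data=True,    # avoids downloading the full dataset
    )
```

Calling build_synthetic_task() should return a tff.simulation.baselines.BaselineTask. Setting use_synthetic_data=True keeps the sketch cheap to run, at the cost of a synthetic vocabulary.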
The goal of the task is to take sequence_length words from a post and predict the next word. Here, all posts are drawn from the Stack Overflow forum, and a client corresponds to a user.
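To make the task statement concrete, here is a minimal pure-Python sketch of turning one post into (context, next word) pairs. The function name next_word_examples and the whitespace tokenization are assumptions made for illustration; the library's own preprocessing may tokenize, pad, and batch differently.

```python
def next_word_examples(post, sequence_length):
    """Returns (context_words, next_word) pairs from a single post."""
    words = post.split()  # naive whitespace tokenization (assumption)
    examples = []
    # Slide a window of sequence_length words; the word after it is the target.
    for start in range(len(words) - sequence_length):
        context = words[start:start + sequence_length]
        target = words[start + sequence_length]
        examples.append((context, target))
    return examples


examples = next_word_examples("how do i sort a list in python", sequence_length=3)
# First pair: (['how', 'do', 'i'], 'sort')
```

A post with fewer than sequence_length + 1 words yields no examples in this sketch.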
Args
train_client_spec | A tff.simulation.baselines.ClientSpec specifying how to preprocess train client data.
eval_client_spec | An optional tff.simulation.baselines.ClientSpec specifying how to preprocess evaluation client data. If set to None, the evaluation datasets will use a batch size of 64 with no extra preprocessing.
sequence_length | A positive integer dictating the length of each word sequence in a client's dataset. By default, this is set to tff.simulation.baselines.stackoverflow.DEFAULT_SEQUENCE_LENGTH.
vocab_size | Integer dictating the number of most frequent words in the entire corpus to use for the task's vocabulary. By default, this is set to tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE.
num_out_of_vocab_buckets | The number of out-of-vocabulary buckets to use.
cache_dir | An optional directory in which to cache the downloaded datasets. If None, they will be cached to ~/.tff/.
use_synthetic_data | A boolean indicating whether to use synthetic Stack Overflow data. This option should only be used for testing purposes, in order to avoid downloading the entire Stack Overflow dataset. A synthetic vocabulary will also be used (not necessarily of the size vocab_size).
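The interaction between vocab_size and num_out_of_vocab_buckets can be pictured with a small pure-Python sketch: the most frequent words receive their own ids, and every other word is hashed into one of the out-of-vocabulary buckets. The helper names build_vocab and word_to_id and the CRC32 hash are assumptions for illustration only; the library's actual lookup table and any special tokens it reserves are not shown here.

```python
import zlib
from collections import Counter


def build_vocab(corpus_words, vocab_size):
    """Keeps the vocab_size most frequent words, most frequent first."""
    counts = Counter(corpus_words)
    return [word for word, _ in counts.most_common(vocab_size)]


def word_to_id(word, vocab_index, num_out_of_vocab_buckets):
    """Maps a word to its vocabulary id, or to an out-of-vocabulary bucket id."""
    if word in vocab_index:
        return vocab_index[word]
    # Deterministic hash so the same unknown word always hits the same bucket.
    bucket = zlib.crc32(word.encode("utf-8")) % num_out_of_vocab_buckets
    return len(vocab_index) + bucket


corpus = "the cat sat on the mat the cat ran".split()
vocab = build_vocab(corpus, vocab_size=2)
vocab_index = {word: i for i, word in enumerate(vocab)}
ids = [word_to_id(w, vocab_index, 1) for w in ["the", "cat", "dog"]]
# ids == [0, 1, 2]: "dog" is out of vocabulary and lands in the single bucket.
```

With more than one bucket, distinct unknown words can map to different ids, which gives a model some ability to tell them apart.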
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-09-20 UTC.