Returns the subsplit of the data for the process.
tfds.split_for_jax_process(
    split: str,
    *,
    process_index: tfds.typing.Dim = None,
    process_count: tfds.typing.Dim = None,
    drop_remainder: bool = False
) -> tfds.typing.SplitArg
In a distributed setting, every process (host) should get a non-overlapping, equally sized slice of the entire data. This function takes a split as input and extracts the slice for the current process index.
Usage:
tfds.load(..., split=tfds.split_for_jax_process('train'))
This function is an alias for:
tfds.even_splits(split, n=jax.process_count())[jax.process_index()]
By default, if examples can't be evenly distributed across processes, some processes
receive more examples than others. To give every process the same number of examples, drop the extra examples with drop_remainder=True.
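To make the slicing concrete, here is a minimal pure-Python sketch of how examples might be divided among processes. It is an illustration only, not the TFDS implementation: the helper `sketch_even_slices` is hypothetical, and it assumes that any remainder goes to the first processes, so the exact boundaries TFDS picks may differ.

```python
from typing import List, Tuple


def sketch_even_slices(
    num_examples: int, n: int, drop_remainder: bool = False
) -> List[Tuple[int, int]]:
    """Return hypothetical (start, end) index ranges, one per process."""
    base, extra = divmod(num_examples, n)
    if drop_remainder:
        extra = 0  # Every process gets exactly `base` examples.
    slices = []
    start = 0
    for i in range(n):
        # Assumption: the first `extra` processes take one additional example.
        size = base + (1 if i < extra else 0)
        slices.append((start, start + size))
        start += size
    return slices


# Hypothetical setup: 11 examples shared by 3 processes.
all_slices = sketch_even_slices(11, n=3)
process_index = 1  # Stand-in for jax.process_index().
print(all_slices[process_index])  # (4, 8)
print(sketch_even_slices(11, n=3, drop_remainder=True))  # [(0, 3), (3, 6), (6, 9)]
```

In this sketch, the default gives slice sizes of 4, 4, and 3, while drop_remainder=True gives every process 3 examples and leaves the last 2 unused. Equal sizes matter when every process must run the same number of steps.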
| Returns | |
|---|---|
| subsplit | The sub-split of the given split for the current process_index. |