Returns the subsplit of the data for the process.
tfds.split_for_jax_process(
    split: str,
    *,
    process_index: tfds.typing.Dim = None,
    process_count: tfds.typing.Dim = None,
    drop_remainder: bool = False
) -> tfds.typing.SplitArg
In a distributed setting, every process (host) should get a non-overlapping, equally sized slice of the entire dataset. This function takes a split as input and extracts the slice for the current process index.
Usage:
tfds.load(..., split=tfds.split_for_jax_process('train'))
This function is an alias for:
tfds.even_splits(split, n=jax.process_count())[jax.process_index()]
By default (drop_remainder=False), no examples are dropped. If examples can't be evenly distributed across processes, you can drop the extra examples with drop_remainder=True.
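To make the effect of drop_remainder concrete, here is a minimal pure-Python sketch of the index arithmetic behind an even split. It is an illustration only, not the TFDS implementation: the helper names even_slices and slice_for_process are hypothetical, and the choice to hand leftover examples to the lowest-indexed processes is an assumption about how the remainder is distributed.

```python
from typing import List, Tuple


def even_slices(
    num_examples: int,
    process_count: int,
    drop_remainder: bool = False,
) -> List[Tuple[int, int]]:
    """Returns half-open (start, end) index ranges, one per process.

    Hypothetical helper for illustration; not part of the tfds API.
    """
    base, extra = divmod(num_examples, process_count)
    if drop_remainder:
        # Leftover examples are not assigned to any process.
        extra = 0
    slices = []
    start = 0
    for index in range(process_count):
        # Assumption: leftovers go to the lowest-indexed processes first.
        size = base + (1 if index < extra else 0)
        slices.append((start, start + size))
        start += size
    return slices


def slice_for_process(
    num_examples: int,
    process_index: int,
    process_count: int,
    drop_remainder: bool = False,
) -> Tuple[int, int]:
    """Mirrors the `even_splits(...)[process_index]` indexing pattern."""
    return even_slices(num_examples, process_count, drop_remainder)[process_index]


if __name__ == "__main__":
    print(even_slices(11, 3))                       # slice sizes 4, 4, 3
    print(even_slices(11, 3, drop_remainder=True))  # slice sizes 3, 3, 3
```

With 11 examples and 3 processes, the sketch yields slices of 4, 4, and 3 examples when nothing is dropped, and 3, 3, and 3 when the remainder is dropped. Equal slice sizes matter when every process must run the same number of steps, which is the usual reason to set drop_remainder=True.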