tfds.split_for_jax_process

Returns the subsplit of the data for the process.

In a distributed setting, each process/host should get a non-overlapping, equally sized slice of the entire data. This function takes a split as input and extracts the slice for the current process index.

Usage:

tfds.load(..., split=tfds.split_for_jax_process('train'))

This function is an alias for:

tfds.even_splits(split, n=jax.process_count())[jax.process_index()]

If examples can't be evenly distributed across processes, you can drop the extra examples with drop_remainder=True.
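
To make the effect of drop_remainder concrete, here is a minimal pure-Python sketch of the size arithmetic only. The helper name subsplit_sizes is hypothetical and is not part of TFDS; it assumes that leftover examples go to the first subsplits, as the drop_remainder description states, and that with drop_remainder=True every subsplit is truncated to the same size.

```python
def subsplit_sizes(num_examples: int, n: int, drop_remainder: bool = False) -> list[int]:
    """Hypothetical helper (not a TFDS API): examples per subsplit for n subsplits."""
    base, extra = divmod(num_examples, n)
    if drop_remainder:
        # Every subsplit gets the same size; the `extra` examples are dropped.
        return [base] * n
    # The first `extra` subsplits receive one additional example each.
    return [base + 1 if i < extra else base for i in range(n)]


print(subsplit_sizes(11, 3))                       # [4, 4, 3]
print(subsplit_sizes(11, 3, drop_remainder=True))  # [3, 3, 3]
```

With drop_remainder=True, two of the 11 examples are never seen by any process, which is the price of keeping every process at the same number of examples.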

split Split to distribute across hosts (e.g. train[75%:], train[:800]+validation[:100]).
process_index Process index in [0, process_count). Defaults to jax.process_index().
process_count Number of processes. Defaults to jax.process_count().
drop_remainder Drop examples if the number of examples in the dataset is not evenly divisible by the number of processes (n). If False, examples are distributed evenly across subsplits, starting with the first. For example, if there are 11 examples with n=3, the subsplits will contain [4, 4, 3] examples respectively.

subsplit The sub-split of the given split for the current process_index.
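
As an illustration of what "non-overlapping" means for the returned sub-split, the sketch below computes one possible (start, stop) example range per process. The helper process_slice is hypothetical and not a TFDS function, and the contiguous layout is an assumption made for illustration; the boundaries TFDS actually chooses are not specified here.

```python
def process_slice(num_examples, process_index, process_count, drop_remainder=False):
    """Hypothetical helper (not a TFDS API): (start, stop) range for one process."""
    if not 0 <= process_index < process_count:
        raise ValueError("process_index must be in [0, process_count)")
    base, extra = divmod(num_examples, process_count)
    if drop_remainder:
        # All processes get exactly `base` examples; the tail is dropped.
        start = process_index * base
        return start, start + base
    # Earlier processes hold `base` examples each, plus one extra example
    # for each of the first `extra` processes.
    start = process_index * base + min(process_index, extra)
    size = base + (1 if process_index < extra else 0)
    return start, start + size


slices = [process_slice(11, i, 3) for i in range(3)]
print(slices)  # [(0, 4), (4, 8), (8, 11)]
```

Each range begins where the previous one ends, so no example is read by two processes, and without drop_remainder the ranges together cover all 11 examples.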