Gets the current worker index and the total number of workers.

This method should be called by a worker, inside a tf.function running in the worker context. In practice, call it from the dataset_fn(context) function.

Currently, context is ignored as it is not populated by the ParameterServerStrategyV2. However, context should still be provided for compatibility with future API changes.

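from typing import Optional

import tensorflow as tf
from tensorflow.python.distribute import distribute_lib
from tensorflow_decision_forests import keras

paths = [...list of dataset files]

def dataset_fn(context: Optional[distribute_lib.InputContext] = None):
  # Distributed dataset_fn.

  # Dataset of file paths; each worker reads a disjoint subset of them.
  ds_path = tf.data.Dataset.from_tensor_slices(paths)

  if context is not None:
    current_worker = keras.get_worker_idx_and_num_workers(context)
    assert current_worker.num_workers > 1, "Not distributed dataset reading"
    ds_path = ds_path.shard(
        num_shards=current_worker.num_workers,
        index=current_worker.worker_index)

  # Load the examples from "ds_path", for example with
  # tf.data.experimental.CsvDataset.

  def read_csv_file(path):
    csv_columns = [ ... ]
    return tf.data.experimental.CsvDataset(path, csv_columns, header=False)

  ds_columns = ds_path.interleave(read_csv_file)

  def extract_label(*columns):
    # The last column is the label; the other columns are the features.
    return columns[0:-1], columns[-1]

  return ds_columns.map(extract_label)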


Args:
  context: Distribution strategy input context.

Returns:
  The index of the worker (worker_index, a tensor) and the total number of workers (num_workers, an integer).
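
The effect of sharding the file list by worker index can be pictured with a small pure-Python sketch. It is only an illustration: the shard helper below is a hypothetical stand-in that mimics the documented behavior of tf.data.Dataset.shard (keep every num_shards-th element, starting at index), the file names are made up, and the worker index is a plain integer here rather than a tensor.

```python
from typing import List, Sequence


def shard(items: Sequence[str], num_shards: int, index: int) -> List[str]:
    """Hypothetical stand-in for Dataset.shard on a plain Python sequence.

    Keeps the items whose position i satisfies i % num_shards == index.
    """
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return [item for i, item in enumerate(items) if i % num_shards == index]


# Made-up file names, used only for illustration.
paths = [f"part-{i}.csv" for i in range(5)]
num_workers = 2

# One list of files per worker, as each worker would see it after sharding.
per_worker = [shard(paths, num_workers, w) for w in range(num_workers)]

for worker_index, files in enumerate(per_worker):
    print(worker_index, files)
```

Each file lands on exactly one worker, so the workers together read the whole dataset without overlap. With fewer files than workers, some workers would receive no files at all.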