Gets the current worker index and the total number of workers.
tfdf.keras.get_worker_idx_and_num_workers(
    context: distribute_lib.InputContext
) -> WorkerIndex
This method should be called by a worker, inside a `tf.function` that runs in the worker context. In practice, call it from the `dataset_fn(context)` method.
Currently, `context` is ignored as it is not populated by the `ParameterServerStrategyV2`. However, `context` should still be provided for compatibility with future API changes.
Usage examples | |
---|---|
    paths = [...list of dataset files]

    def dataset_fn(context: Optional[distribute_lib.InputContext] = None):
      # Distributed dataset_fn.

      ds_path = tf.data.Dataset.from_tensor_slices(paths)

      if context is not None:
        # Give each worker its own disjoint subset of the files.
        current_worker = tfdf.keras.get_worker_idx_and_num_workers(context)
        assert current_worker.num_workers > 1, "Not distributed dataset reading"
        ds_path = ds_path.shard(
            num_shards=current_worker.num_workers,
            index=current_worker.worker_index)

      # Load the examples from "ds_path", for example with
      # tf.data.experimental.CsvDataset.

      def read_csv_file(path):
        csv_columns = [ ... ]
        return tf.data.experimental.CsvDataset(path, csv_columns, header=False)

      ds_columns = ds_path.interleave(read_csv_file)

      def extract_label(*columns):
        # The last column is the label; all other columns are features.
        return columns[0:-1], columns[-1]

      return ds_columns.map(extract_label).batch(batch_size)
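To make the sharding step in the usage example above more concrete, the following is a minimal pure-Python sketch of how `shard(num_shards, index)` divides a list of file paths: it keeps the elements whose position modulo `num_shards` equals `index`. The names `WorkerIndexSketch` and `shard_paths` and the file names are hypothetical stand-ins for illustration and are not part of TF-DF. In real code the worker index is a tensor, not a Python integer.

```python
from typing import List, NamedTuple


class WorkerIndexSketch(NamedTuple):
    """Hypothetical stand-in for the returned value (not the TF-DF class)."""

    worker_index: int
    num_workers: int


def shard_paths(paths: List[str], worker: WorkerIndexSketch) -> List[str]:
    """Emulates Dataset.shard: keep positions where pos % num_workers == worker_index."""
    return [
        path
        for position, path in enumerate(paths)
        if position % worker.num_workers == worker.worker_index
    ]


# Hypothetical file names, for illustration only.
example_paths = [f"part-{i}.csv" for i in range(5)]
num_workers = 2

# One list of files per worker; together they cover every file exactly once.
assignments = [
    shard_paths(example_paths, WorkerIndexSketch(idx, num_workers))
    for idx in range(num_workers)
]
```

With five files and two workers, worker 0 receives the files at positions 0, 2 and 4, and worker 1 receives those at positions 1 and 3, so no file is read twice.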
Args | |
---|---|
`context` | Distribution strategy input context. |
Returns | |
---|---|
A `WorkerIndex` with the index of the worker (`worker_index`, a tensor) and the total number of workers (`num_workers`, an integer). |