tf.keras.distribution.DataParallel

Distribution for data parallelism.

You can create this instance by specifying either the device_mesh argument or the devices argument, but not both.

The device_mesh argument must be a DeviceMesh instance and should be 1D. If the mesh has multiple axes, the first axis is treated as the data parallel dimension, and a warning is raised.

When a list of devices is provided, it is used to construct a 1D mesh.

When both arguments are absent, list_devices() is used to detect any available devices and create a 1D mesh from them.
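The precedence above (device_mesh, then devices, then auto-detection) can be sketched in plain Python. This is an illustration only: the dict-based mesh, the "batch" axis name, and the detect callback are stand-ins for the real DeviceMesh class and list_devices(), not the actual API.

```python
def resolve_mesh(device_mesh=None, devices=None, detect=lambda: ["cpu:0"]):
    """Sketch of the argument precedence: device_mesh > devices > auto-detect."""
    if device_mesh is not None and devices is not None:
        raise ValueError("Pass either device_mesh or devices, not both.")
    if device_mesh is not None:
        if len(device_mesh["shape"]) > 1:
            # The real implementation warns and uses the first axis for data.
            print("warning: multi-axis mesh; first axis used for data parallelism")
        return device_mesh
    if devices is None:
        devices = detect()  # stand-in for list_devices()
    # Build a 1D mesh over the available devices.
    return {"shape": (len(devices),), "axis_names": ("batch",), "devices": list(devices)}

mesh = resolve_mesh(devices=["gpu:0", "gpu:1"])
# mesh["shape"] == (2,), mesh["axis_names"] == ("batch",)
```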

Args
device_mesh Optional DeviceMesh instance.
devices Optional list of devices.

Attributes

device_mesh

Methods

distribute_dataset

Create a distributed dataset instance from the original user dataset.

Args
dataset the original global dataset instance. Only tf.data.Dataset is supported at the moment.

Returns
a sharded tf.data.Dataset instance, which will produce data for the current local worker/process.
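The per-worker sharding can be illustrated with plain Python lists standing in for tf.data.Dataset. This is a simplified sketch of the idea (each worker sees a disjoint slice of the global dataset), not the actual lazy implementation.

```python
def shard_for_worker(dataset, num_workers, worker_id):
    """Each worker keeps every num_workers-th example, offset by its id (sketch)."""
    if not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    return [x for i, x in enumerate(dataset) if i % num_workers == worker_id]

examples = list(range(8))
shard_for_worker(examples, num_workers=2, worker_id=0)  # → [0, 2, 4, 6]
shard_for_worker(examples, num_workers=2, worker_id=1)  # → [1, 3, 5, 7]
```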

get_data_layout

Retrieve the TensorLayout for the input data.

Args
data_shape shape for the input data in list or tuple format.

Returns
The TensorLayout for the data, which can be used by backend.distribute_value() to redistribute the input data.
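Under data parallelism, only the leading (batch) dimension of the input is sharded; every other dimension is replicated. Representing a layout as a tuple of axis names, with None meaning replicated, a sketch looks like this ("batch" as the axis name is an assumption, and the tuple is a stand-in for the real TensorLayout):

```python
def data_layout(data_shape, batch_axis="batch"):
    """Shard dim 0 across the mesh's batch axis; replicate all other dims (sketch)."""
    return (batch_axis,) + (None,) * (len(data_shape) - 1)

data_layout((32, 28, 28, 3))  # → ('batch', None, None, None)
```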

get_tensor_layout

Retrieve the TensorLayout for the intermediate tensor.

Args
path a string path for the corresponding tensor.

Returns
The TensorLayout for the intermediate tensor, which can be used by backend.relayout() to reshard the tensor. Can also be None.
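In a pure data-parallel setup, intermediate tensors usually need no explicit resharding, which is the None case mentioned above. A minimal sketch of that behavior (the path-keyed lookup table is hypothetical, not part of the real API):

```python
def tensor_layout(path, layout_map=None):
    """Look up an explicit layout for the tensor at `path` (sketch).
    Returning None tells the backend to leave the tensor's sharding as-is."""
    if layout_map is None:
        return None
    return layout_map.get(path)

tensor_layout("model/dense/output")  # → None
```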

get_variable_layout

Retrieve the TensorLayout for the variable.

Args
variable A KerasVariable instance.

Returns
The TensorLayout for the variable, which can be used by backend.distribute_value() to redistribute the variable.
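Under data parallelism, variables are typically fully replicated: every device holds a complete copy, so no dimension is sharded. In the same tuple-of-axis-names notation used above (a sketch, not the real TensorLayout class):

```python
def variable_layout(variable_shape):
    """Replicate the variable on every device: no dimension is sharded (sketch)."""
    return (None,) * len(variable_shape)

variable_layout((784, 128))  # → (None, None)
```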

scope

Context manager to make the Distribution current.