Represents a Mesh configuration over a certain list of Mesh Dimensions.
tf.experimental.dtensor.Mesh(
dim_names: List[str],
global_device_ids: np.ndarray,
local_device_ids: List[int],
local_devices: List[tf.compat.v1.DeviceSpec],
mesh_name: str = '',
global_devices: Optional[List[tf_device.DeviceSpec]] = None
)
A mesh consists of named dimensions with sizes, which describe how a set of devices are arranged. Defining tensor layouts in terms of mesh dimensions allows us to efficiently determine the communication required when computing an operation with tensors of different layouts.
A mesh provides information not only about the placement of the tensors but also about the topology of the underlying devices. For example, we can group 8 TPUs as a 1-D array for data parallelism, or as a 2x4 grid for 2-way data parallelism combined with 4-way model parallelism.
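To make the 2x4 example concrete, the following pure-Python sketch (not DTensor code) lays out 8 device IDs in row-major order on a grid with two hypothetical dimension names, `batch` (size 2, data parallelism) and `model` (size 4, model parallelism). The names and the row-major ordering are assumptions chosen for illustration. It shows which devices form one data-parallel replica and which devices hold the same model shard.

```python
# Pure-Python illustration of a 2x4 mesh; not the DTensor implementation.
# Assumed (hypothetical) dimension names: 'batch' for 2-way data parallelism
# and 'model' for 4-way model parallelism. Device IDs are laid out row-major.

BATCH, MODEL = 2, 4
NUM_DEVICES = BATCH * MODEL  # 8 devices, as in the TPU example

# A nested-list stand-in for a global_device_ids array of shape [BATCH, MODEL].
grid = [[b * MODEL + m for m in range(MODEL)] for b in range(BATCH)]

# Devices with the same 'batch' index form one data-parallel replica; within
# that replica, each device holds a different quarter of the model.
replica_groups = [list(row) for row in grid]

# Devices with the same 'model' index hold the same model shard, one copy per
# replica; these are the devices whose gradients would typically be combined.
shard_groups = [[grid[b][m] for b in range(BATCH)] for m in range(MODEL)]

# Viewed as a 1-D mesh, the same 8 devices are simply 8 data-parallel replicas.
flat_mesh = list(range(NUM_DEVICES))

print(replica_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(shard_groups)    # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

The 1-D and 2-D views contain the same devices; only the grouping, and therefore the communication pattern, differs.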
Methods
as_proto
as_proto() -> layout_pb2.MeshProto
Returns the mesh as a MeshProto protocol buffer.
contains_dim
contains_dim(
dim_name: str
) -> bool
Returns True if a Mesh contains the given dimension name.
coords
coords(
device_idx: int
) -> tf.Tensor
Converts the device index into a tensor of mesh coordinates.
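The conversion can be sketched in plain Python as follows. This is a hypothetical helper, not the DTensor implementation: it returns a list instead of a `tf.Tensor`, takes the mesh shape explicitly, and assumes device indices are laid out row-major (last dimension varying fastest), which matches the `unravel_index` example for a 3x2 mesh.

```python
import math
from typing import List


def device_coords(device_idx: int, shape: List[int]) -> List[int]:
    """Hypothetical helper: maps a row-major device index to mesh coordinates."""
    if not 0 <= device_idx < math.prod(shape):
        raise ValueError(f"device index {device_idx} out of range for shape {shape}")
    coords = []
    # Peel off one coordinate per dimension, starting with the last
    # (fastest-varying) dimension.
    for size in reversed(shape):
        device_idx, rem = divmod(device_idx, size)
        coords.append(rem)
    # Coordinates were collected last-dimension-first; restore mesh order.
    return coords[::-1]


print(device_coords(5, [3, 2]))  # [2, 1]
print(device_coords(6, [2, 4]))  # [1, 2]
```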
device_type
device_type() -> str
Returns the device_type of a Mesh.
dim_size
dim_size(
dim_name: str
) -> int
Returns the size of a dimension.
from_proto
@staticmethod
from_proto(proto: layout_pb2.MeshProto) -> 'Mesh'
Constructs a mesh instance from an input proto.
from_string
@staticmethod
from_string(mesh_str: str) -> 'Mesh'
Constructs a mesh instance from its string representation, such as the one returned by to_string.
host_mesh
host_mesh()
Returns the 1-1 mapped host mesh.
is_remote
is_remote() -> bool
Returns True if a Mesh contains only remote devices.
local_device_ids
local_device_ids() -> List[int]
Returns a list of local device IDs.
local_device_locations
local_device_locations() -> List[Dict[str, int]]
Returns a list of local device locations.
A device location is a dictionary from dimension names to indices on those dimensions.
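As a sketch of what such a list looks like, the hypothetical helper below (not part of the DTensor API) computes device locations for a given set of local device IDs. It assumes row-major device IDs starting at 0 and, for the example, a 2x4 mesh with dimensions 'x' and 'y' where the local host owns devices 4 to 7; both are assumptions for illustration.

```python
from typing import Dict, List


def local_locations(local_ids: List[int], dim_names: List[str],
                    shape: List[int]) -> List[Dict[str, int]]:
    """Hypothetical helper: device IDs -> {dim_name: index} dictionaries."""
    locations = []
    for device_id in local_ids:
        loc = {}
        rest = device_id
        # Row-major layout: the last dimension varies fastest.
        for name, size in zip(reversed(dim_names), reversed(shape)):
            rest, loc[name] = divmod(rest, size)
        # Rebuild the dict so keys appear in mesh dimension order.
        locations.append({name: loc[name] for name in dim_names})
    return locations


print(local_locations([4, 5, 6, 7], ['x', 'y'], [2, 4]))
# [{'x': 1, 'y': 0}, {'x': 1, 'y': 1}, {'x': 1, 'y': 2}, {'x': 1, 'y': 3}]
```

Every local device in this example has 'x' index 1, so this host holds one full row of the mesh.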
local_devices
local_devices() -> List[str]
Returns a list of local device specs represented as strings.
min_global_device_id
min_global_device_id() -> int
Returns the minimum global device ID.
num_local_devices
num_local_devices() -> int
Returns the number of local devices.
shape
shape() -> List[int]
Returns the shape of the mesh.
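The dimension queries (contains_dim, dim_size, shape) all derive from the ordered list of named dimensions. The toy class below is a hypothetical stand-in that models only these relationships; it is not the real Mesh class. In particular, its dim_size raises ValueError for an unknown name, and the real method may behave differently.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyMesh:
    """Hypothetical stand-in modelling Mesh's dimension-related queries."""
    dim_names: List[str]
    dim_sizes: List[int]

    def contains_dim(self, dim_name: str) -> bool:
        # A dimension is part of the mesh iff its name is listed.
        return dim_name in self.dim_names

    def dim_size(self, dim_name: str) -> int:
        # list.index raises ValueError if the name is not a mesh dimension.
        return self.dim_sizes[self.dim_names.index(dim_name)]

    def shape(self) -> List[int]:
        # Sizes in the same order as dim_names; return a copy.
        return list(self.dim_sizes)


mesh = ToyMesh(dim_names=['batch', 'model'], dim_sizes=[2, 4])
print(mesh.shape())            # [2, 4]
print(mesh.dim_size('model'))  # 4
print(mesh.contains_dim('z'))  # False
```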
to_string
to_string() -> str
Returns the string representation of the mesh.
unravel_index
unravel_index()
Returns a dictionary from device ID to {dim_name: dim_index}.
For example, for a 3x2 mesh with dimensions 'x' and 'y', it returns:
{ 0: {'x': 0, 'y': 0},
1: {'x': 0, 'y': 1},
2: {'x': 1, 'y': 0},
3: {'x': 1, 'y': 1},
4: {'x': 2, 'y': 0},
5: {'x': 2, 'y': 1} }
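The mapping in this example can be reproduced with a short pure-Python sketch. It is a hypothetical reimplementation, not the DTensor source: it takes the dimension names and shape explicitly and assumes device IDs start at 0 and are numbered row-major, as in the 3x2 example.

```python
import itertools
from typing import Dict, List


def unravel_index(dim_names: List[str],
                  shape: List[int]) -> Dict[int, Dict[str, int]]:
    """Hypothetical reimplementation: device ID -> {dim_name: dim_index}."""
    result = {}
    # itertools.product yields index tuples in row-major order (last position
    # varies fastest), so the enumerate counter equals the row-major device ID.
    all_indices = itertools.product(*(range(size) for size in shape))
    for device_id, indices in enumerate(all_indices):
        result[device_id] = dict(zip(dim_names, indices))
    return result


print(unravel_index(['x', 'y'], [3, 2]))
```

For a mesh whose global device IDs do not start at 0 (see min_global_device_id), the keys would need to be offset accordingly.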
__contains__
__contains__(
dim_name: str
) -> bool
__eq__
__eq__(
other
)
Return self==value.
__getitem__
__getitem__(
dim_name: str
) -> MeshDimension