Represents the layout information of a DTensor.
tf.experimental.dtensor.Layout(
    sharding_specs: List[str],
    mesh: tf.experimental.dtensor.Mesh
)
A layout describes how a distributed tensor is partitioned across a mesh (and
thus across devices). For each axis of the tensor, the corresponding
sharding spec indicates which dimension of the mesh it is sharded over. The
special sharding spec UNSHARDED indicates that the axis is replicated on
all devices of the mesh.
For example, let's consider a 1-D mesh:
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
This mesh arranges 6 TPU devices into a 1-D array. Layout([UNSHARDED], mesh)
is a layout for a rank-1 tensor that is replicated across all 6 devices.
For another example, let's consider a 2-D mesh:
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
[("x", 3), ("y", 2)])
This mesh arranges 6 TPU devices into a 3x2 2-D array.
Layout(["x", UNSHARDED], mesh) is a layout for rank-2 tensor whose first
axis is sharded on mesh dimension "x" and the second axis is replicated. If we
place np.arange(6).reshape((3, 2)) using this layout, the individual
components tensors would look like:
Device | Component
TPU:0  | [[0, 1]]
TPU:1  | [[0, 1]]
TPU:2  | [[2, 3]]
TPU:3  | [[2, 3]]
TPU:4  | [[4, 5]]
TPU:5  | [[4, 5]]
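The snippet below is a minimal sketch of this placement using the dtensor
module as it appears in TF 2.9+. Since most machines do not have 6 TPUs, it
assumes 6 virtual CPU devices carved out of one physical CPU, so TPU:i in the
table corresponds to CPU:i here; copy_to_mesh places a replicated copy, and
relayout then shards it.

import numpy as np
import tensorflow as tf
from tensorflow.experimental import dtensor

# Carve 6 virtual CPU devices out of the first physical CPU so the example
# runs on a single host (stand-ins for the 6 TPUs in the table above).
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 6)

# A 3x2 mesh over the 6 devices, matching the 2-D example above.
mesh = dtensor.create_mesh([("x", 3), ("y", 2)], device_type="CPU")

# First tensor axis sharded over mesh dimension "x", second axis replicated.
layout = dtensor.Layout(["x", dtensor.UNSHARDED], mesh)

# Place a replicated copy on the mesh, then reshard it to the target layout.
t = tf.constant(np.arange(6).reshape((3, 2)))
d = dtensor.copy_to_mesh(t, dtensor.Layout.replicated(mesh, rank=2))
d = dtensor.relayout(d, layout)

# Each device holds one 1x2 component, as in the table above.
for device, component in zip(mesh.local_devices(), dtensor.unpack(d)):
    print(device, component.numpy())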
Methods
as_proto
as_proto() -> layout_pb2.LayoutProto
Create a proto representation of a layout.
batch_sharded
@staticmethod
batch_sharded(
    mesh: tf.experimental.dtensor.Mesh,
    batch_dim: str,
    rank: int
) -> 'Layout'
Returns a layout sharded on the batch (first) dimension.
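As a sketch of the assumed equivalence (reusing mesh from the example above),
batch_sharded is shorthand for a layout that shards the leading axis and
replicates the rest:

# Assumed equivalence: shard the leading (batch) axis of a rank-3 tensor on
# mesh dimension "x" and replicate the remaining axes.
by_helper = dtensor.Layout.batch_sharded(mesh, batch_dim="x", rank=3)
by_hand = dtensor.Layout(["x", dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)
assert by_helper == by_hand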
delete
delete(
    dims: List[int]
) -> 'Layout'
Returns the layout with the given dimensions deleted.
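For instance, a sketch of the assumed behavior: deleting tensor dimension 1
from the rank-2 layout above should leave a rank-1 layout over the same mesh.

# Assumed behavior: dropping tensor dimension 1 removes its sharding spec.
layout = dtensor.Layout(["x", dtensor.UNSHARDED], mesh)
assert layout.delete([1]) == dtensor.Layout(["x"], mesh)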
from_str
@staticmethod
from_str(
    layout_str: bytes
) -> 'Layout'
Creates an instance from a serialized Protobuf binary string.
from_string
@staticmethod
from_string(
    layout_str: str
) -> 'Layout'
Creates an instance from a human-readable string.
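Together with to_string, this gives a human-readable round trip; a small
sketch, reusing layout from above:

# Round-trip a layout through its human-readable string form.
s = layout.to_string()
assert dtensor.Layout.from_string(s) == layout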
inner_sharded
@staticmethod
inner_sharded(
    mesh: tf.experimental.dtensor.Mesh,
    inner_dim: str,
    rank: int
) -> 'Layout'
Returns a layout sharded on the innermost (last) dimension.
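Mirroring batch_sharded, a sketch of the assumed equivalence:

# Assumed equivalence: shard the trailing axis on mesh dimension "y" and
# replicate the leading axis.
inner = dtensor.Layout.inner_sharded(mesh, inner_dim="y", rank=2)
assert inner == dtensor.Layout([dtensor.UNSHARDED, "y"], mesh)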
is_fully_replicated
is_fully_replicated() -> bool
Returns True if all tensor axes are replicated.
mesh_proto
mesh_proto() -> layout_pb2.MeshProto
Returns the underlying mesh in Protobuf format.
num_shards
num_shards(
    idx: int
) -> int
Returns the number of shards for tensor dimension idx.
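For the Layout(["x", UNSHARDED], mesh) example on the 3x2 mesh above, axis 0
is split three ways and axis 1 is not split at all:

# Axis 0 is sharded over the 3 devices of mesh dimension "x"; axis 1 is
# replicated, i.e. it has a single shard.
assert layout.num_shards(0) == 3
assert layout.num_shards(1) == 1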
offset_to_shard
offset_to_shard()
Mapping from offset in a flattened list to shard index.
offset_tuple_to_global_index
offset_tuple_to_global_index(
    offset_tuple
)
Mapping from offset to index in global tensor.
replicated
@staticmethod
replicated(
    mesh: tf.experimental.dtensor.Mesh,
    rank: int
) -> 'Layout'
Returns a fully replicated layout of the given rank.
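A sketch of the assumed equivalence, which also exercises
is_fully_replicated:

# A replicated rank-2 layout is UNSHARDED on every tensor axis.
rep = dtensor.Layout.replicated(mesh, rank=2)
assert rep == dtensor.Layout([dtensor.UNSHARDED] * 2, mesh)
assert rep.is_fully_replicated()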
serialized_string
serialized_string() -> bytes
Returns a serialized Protobuf binary string representation.
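Paired with from_str, this gives a binary round trip; a small sketch:

# Round-trip a layout through its Protobuf binary form.
b = layout.serialized_string()
assert dtensor.Layout.from_str(b) == layout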
to_string
to_string() -> str
Returns a human-readable string representation.
unravel
unravel(
    unpacked_tensors: List[np.ndarray]
) -> np.ndarray
Convert a flattened list of shards into a sharded array.
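A sketch of the assumed usage, reusing the sharded DTensor d from the
placement example above: unpack yields the per-device components, and unravel
is assumed here to stitch them back into the global array.

# Assumed usage: invert the sharding by reassembling the component arrays
# (one per device, in mesh order) into the global ndarray.
components = [c.numpy() for c in dtensor.unpack(d)]
global_array = layout.unravel(components)
# global_array should equal np.arange(6).reshape((3, 2))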
__eq__
__eq__(
    other
) -> bool
Return self==value.