Represents the layout information of a DTensor.
tf.experimental.dtensor.Layout(
sharding_specs: List[str],
mesh: tf.experimental.dtensor.Mesh
)
A layout describes how a distributed tensor is partitioned across a mesh (and thus across devices). For each axis of the tensor, the corresponding sharding spec names the mesh dimension that the axis is sharded over. The special sharding spec UNSHARDED indicates that the axis is replicated on all the devices of the mesh.
For example, let's consider a 1-D mesh:
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
This mesh arranges 6 TPU devices into a 1-D array. Layout([UNSHARDED], mesh) is a layout for a rank-1 tensor that is replicated on all 6 devices.
For another example, let's consider a 2-D mesh:
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
[("x", 3), ("y", 2)])
This mesh arranges 6 TPU devices into a 3x2 2-D array. Layout(["x", UNSHARDED], mesh) is a layout for a rank-2 tensor whose first axis is sharded on mesh dimension "x" and whose second axis is replicated. If we place np.arange(6).reshape((3, 2)) using this layout, the individual component tensors would look like:
Device | Component
TPU:0  | [[0, 1]]
TPU:1  | [[0, 1]]
TPU:2  | [[2, 3]]
TPU:3  | [[2, 3]]
TPU:4  | [[4, 5]]
TPU:5  | [[4, 5]]
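The component table for the 2-D example can be reproduced with a small pure-Python model. This is a sketch, not the DTensor API: it assumes devices are laid out on the mesh in row-major order (the last mesh dimension varies fastest, which matches the table), models UNSHARDED as the string "unsharded", and uses hypothetical helpers (mesh_coordinates, split, component_for_device) that only handle rank-2 tensors stored as lists of rows.

```python
from itertools import product

UNSHARDED = "unsharded"  # stand-in for dtensor.UNSHARDED in this sketch


def mesh_coordinates(mesh_dims):
    """Map each device offset to its {dim_name: index} coordinate (row-major)."""
    names = [name for name, _ in mesh_dims]
    sizes = [size for _, size in mesh_dims]
    return [dict(zip(names, idx)) for idx in product(*(range(s) for s in sizes))]


def split(values, num_shards, shard_index):
    """Return the shard_index-th of num_shards equal, contiguous slices."""
    size = len(values) // num_shards
    return values[shard_index * size:(shard_index + 1) * size]


def component_for_device(global_rows, sharding_specs, mesh_dims, offset):
    """Compute one device's component of a rank-2 tensor given as a list of rows."""
    sizes = dict(mesh_dims)
    coord = mesh_coordinates(mesh_dims)[offset]
    row_spec, col_spec = sharding_specs
    rows = global_rows
    if row_spec != UNSHARDED:  # shard axis 0 over the named mesh dimension
        rows = split(rows, sizes[row_spec], coord[row_spec])
    if col_spec != UNSHARDED:  # shard axis 1 over the named mesh dimension
        rows = [split(r, sizes[col_spec], coord[col_spec]) for r in rows]
    return rows


global_tensor = [[0, 1], [2, 3], [4, 5]]  # same values as np.arange(6).reshape((3, 2))
mesh_dims = [("x", 3), ("y", 2)]
for offset in range(6):
    # Prints the same Device / Component pairs as the table above.
    print(f"TPU:{offset}", component_for_device(global_tensor, ["x", UNSHARDED], mesh_dims, offset))
```

Swapping the specs to [UNSHARDED, "y"] in this model instead splits each row across the two "y" devices and replicates over "x".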
Methods
as_proto
as_proto() -> layout_pb2.LayoutProto
Create a proto representation of a layout.
batch_sharded
@staticmethod
batch_sharded(
mesh: tf.experimental.dtensor.Mesh,
batch_dim: str,
rank: int
) -> 'Layout'
Returns a layout of the given rank whose first (batch) axis is sharded on mesh dimension batch_dim and whose remaining axes are replicated.
delete
delete(
dims: List[int]
) -> 'Layout'
Returns a layout on the same mesh with the given tensor dimensions (and their sharding specs) deleted.
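As a rough model of delete, the sketch below treats a layout as just its list of sharding specs and drops the entries at the given tensor-axis indices; the mesh is assumed to stay unchanged. The helper name delete_dims and the string value of UNSHARDED are assumptions for illustration, not part of the DTensor API.

```python
UNSHARDED = "unsharded"  # stand-in for dtensor.UNSHARDED in this sketch


def delete_dims(sharding_specs, dims):
    """Drop the sharding specs at the given tensor-axis indices, keeping order."""
    return [spec for i, spec in enumerate(sharding_specs) if i not in dims]


print(delete_dims(["x", UNSHARDED, "y"], [1]))  # ['x', 'y']
```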
from_str
@staticmethod
from_str( layout_str: bytes ) -> 'Layout'
Creates an instance from a serialized Protobuf binary string.
from_string
@staticmethod
from_string( layout_str: str ) -> 'Layout'
Creates an instance from a human-readable string.
inner_sharded
@staticmethod
inner_sharded(
mesh: tf.experimental.dtensor.Mesh,
inner_dim: str,
rank: int
) -> 'Layout'
Returns a layout of the given rank whose last (innermost) axis is sharded on mesh dimension inner_dim and whose remaining axes are replicated.
is_fully_replicated
is_fully_replicated() -> bool
Returns True if all tensor axes are replicated.
mesh_proto
mesh_proto() -> layout_pb2.MeshProto
Returns the underlying mesh in Protobuf format.
num_shards
num_shards(
idx: int
) -> int
Returns the number of shards for tensor dimension idx.
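The relationship between num_shards and is_fully_replicated can be sketched as follows: a replicated axis has one shard, a sharded axis has as many shards as its mesh dimension has devices, and a layout is fully replicated when every axis has exactly one shard. This is a hypothetical pure-Python model that handles only UNSHARDED and mesh-dimension names; any other special sharding specs are ignored.

```python
UNSHARDED = "unsharded"  # stand-in for dtensor.UNSHARDED in this sketch


def num_shards(sharding_specs, mesh_dims, idx):
    """Shards along tensor axis idx: 1 if replicated, else the mesh dim's size."""
    spec = sharding_specs[idx]
    if spec == UNSHARDED:
        return 1
    return dict(mesh_dims)[spec]


def is_fully_replicated(sharding_specs, mesh_dims):
    """True if every tensor axis has exactly one shard."""
    return all(num_shards(sharding_specs, mesh_dims, i) == 1
               for i in range(len(sharding_specs)))


mesh_dims = [("x", 3), ("y", 2)]
specs = ["x", UNSHARDED]
print([num_shards(specs, mesh_dims, i) for i in range(len(specs))])  # [3, 1]
print(is_fully_replicated(specs, mesh_dims))  # False
```

Defining "fully replicated" in terms of shard counts, rather than by checking for UNSHARDED, also treats an axis sharded over a size-1 mesh dimension as replicated.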
offset_to_shard
offset_to_shard()
Mapping from offset in a flattened list to shard index.
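One way to picture this mapping: for each device offset in the flattened device list, record the shard index along every tensor axis (0 for replicated axes, the device's mesh coordinate for sharded ones). The sketch below assumes row-major device order and is an illustrative model, not the library implementation; for the 2-D example it pairs devices exactly as the component table does.

```python
from itertools import product

UNSHARDED = "unsharded"  # stand-in for dtensor.UNSHARDED in this sketch


def offset_to_shard(sharding_specs, mesh_dims):
    """For each flattened device offset, the shard index along each tensor axis."""
    names = [name for name, _ in mesh_dims]
    locations = []
    # Row-major walk over the mesh: the last mesh dimension varies fastest.
    for mesh_index in product(*(range(size) for _, size in mesh_dims)):
        coord = dict(zip(names, mesh_index))
        locations.append(tuple(0 if spec == UNSHARDED else coord[spec]
                               for spec in sharding_specs))
    return locations


print(offset_to_shard(["x", UNSHARDED], [("x", 3), ("y", 2)]))
# [(0, 0), (0, 0), (1, 0), (1, 0), (2, 0), (2, 0)]
```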
offset_tuple_to_global_index
offset_tuple_to_global_index(
offset_tuple
)
Mapping from offset to index in global tensor.
replicated
@staticmethod
replicated(
mesh: tf.experimental.dtensor.Mesh,
rank: int
) -> 'Layout'
Returns a fully replicated layout of the given rank (every axis is UNSHARDED).
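The three static constructors batch_sharded, inner_sharded, and replicated differ only in the sharding-spec list they build for a given rank. The sketch below models just those spec lists; the helper names are hypothetical, and UNSHARDED is assumed to be the string "unsharded".

```python
UNSHARDED = "unsharded"  # stand-in for dtensor.UNSHARDED in this sketch


def replicated_specs(rank):
    # Every axis replicated.
    return [UNSHARDED] * rank


def batch_sharded_specs(batch_dim, rank):
    # First (batch) axis sharded over batch_dim; the rest replicated.
    return [batch_dim] + [UNSHARDED] * (rank - 1)


def inner_sharded_specs(inner_dim, rank):
    # Last (innermost) axis sharded over inner_dim; the rest replicated.
    return [UNSHARDED] * (rank - 1) + [inner_dim]


print(replicated_specs(2))          # ['unsharded', 'unsharded']
print(batch_sharded_specs("x", 3))  # ['x', 'unsharded', 'unsharded']
print(inner_sharded_specs("y", 3))  # ['unsharded', 'unsharded', 'y']
```

Under these assumptions, Layout(batch_sharded_specs("x", 3), mesh) would describe the same partitioning as Layout.batch_sharded(mesh, "x", 3), and likewise for the other two.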
serialized_string
serialized_string() -> bytes
Returns a serialized Protobuf binary string representation.
to_string
to_string() -> str
Returns a human-readable string representation.
unravel
unravel(
unpacked_tensors: List[np.ndarray]
) -> np.ndarray
Convert a flattened list of shards into a sharded array.
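Conceptually, unravel places each device's shard at its shard index along every tensor axis, so replicated copies collapse onto the same position. The sketch below models the result as a dict keyed by shard-index tuples rather than an array of shards, assumes row-major device order, and assumes replicas hold equal data; the function name mirrors the method but the code is illustrative only.

```python
from itertools import product

UNSHARDED = "unsharded"  # stand-in for dtensor.UNSHARDED in this sketch


def unravel(unpacked, sharding_specs, mesh_dims):
    """Place each device's shard at its shard index; replicas share a key."""
    names = [name for name, _ in mesh_dims]
    mesh_indices = product(*(range(size) for _, size in mesh_dims))
    grid = {}
    for shard, mesh_index in zip(unpacked, mesh_indices):
        coord = dict(zip(names, mesh_index))
        loc = tuple(0 if spec == UNSHARDED else coord[spec] for spec in sharding_specs)
        grid[loc] = shard  # replicated copies (assumed equal) overwrite each other
    return grid


shards = [[[0, 1]], [[0, 1]], [[2, 3]], [[2, 3]], [[4, 5]], [[4, 5]]]
print(unravel(shards, ["x", UNSHARDED], [("x", 3), ("y", 2)]))
# {(0, 0): [[0, 1]], (1, 0): [[2, 3]], (2, 0): [[4, 5]]}
```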
__eq__
__eq__(
other
) -> bool
Return self==value.