Represents the layout information of a DTensor.
tf.experimental.dtensor.Layout(
sharding_specs: List[str],
mesh: tf.experimental.dtensor.Mesh
)
A layout describes how a distributed tensor is partitioned across a mesh (and
thus across devices). For each axis of the tensor, the corresponding
sharding spec indicates which dimension of the mesh it is sharded over. A
special sharding spec, UNSHARDED, indicates that the axis is replicated on
all the devices of that mesh.
Refer to DTensor Concepts for an in-depth discussion and examples.
For example, let's consider a 1-D mesh:
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
This mesh arranges 6 TPU devices into a 1-D array. Layout([UNSHARDED], mesh)
is a layout for a rank-1 tensor that is replicated on the 6 devices.
For another example, let's consider a 2-D mesh:
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
[("x", 3), ("y", 2)])
This mesh arranges 6 TPU devices into a 3x2 2-D array.
Layout(["x", UNSHARDED], mesh) is a layout for a rank-2 tensor whose first
axis is sharded on mesh dimension "x" and whose second axis is replicated. If we
place np.arange(6).reshape((3, 2)) using this layout, the individual
component tensors would look like:
Device | Component |
---|---|
TPU:0 | [[0, 1]] |
TPU:1 | [[0, 1]] |
TPU:2 | [[2, 3]] |
TPU:3 | [[2, 3]] |
TPU:4 | [[4, 5]] |
TPU:5 | [[4, 5]] |
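The per-device components listed above can be reproduced with a small pure-Python model of the sharding rule. This is only a sketch and does not call DTensor: it assumes devices are assigned to mesh coordinates in row-major order (so TPU:0 and TPU:1 share the same "x" coordinate), it assumes each sharded axis divides evenly, and it uses the string "unsharded" as a placeholder for UNSHARDED. The helper name `component_for_device` is hypothetical.

```python
from itertools import product

UNSHARDED = "unsharded"  # placeholder for dtensor.UNSHARDED (assumed value)

def component_for_device(tensor, specs, mesh_dims, device_index):
    """Return the slice of a rank-2 nested list held by one device."""
    names = [name for name, _ in mesh_dims]
    sizes = [size for _, size in mesh_dims]
    # Mesh coordinates of this device, assuming row-major device order.
    coords = list(product(*[range(s) for s in sizes]))[device_index]
    coord = dict(zip(names, coords))

    def bounds(axis_len, spec):
        # An unsharded axis is kept whole on every device.
        if spec == UNSHARDED:
            return 0, axis_len
        # A sharded axis is split evenly along the named mesh dimension.
        step = axis_len // sizes[names.index(spec)]
        start = coord[spec] * step
        return start, start + step

    r0, r1 = bounds(len(tensor), specs[0])
    c0, c1 = bounds(len(tensor[0]), specs[1])
    return [row[c0:c1] for row in tensor[r0:r1]]

tensor = [[0, 1], [2, 3], [4, 5]]  # same values as np.arange(6).reshape((3, 2))
mesh_dims = [("x", 3), ("y", 2)]
components = [
    component_for_device(tensor, ["x", UNSHARDED], mesh_dims, d)
    for d in range(6)
]
for d, comp in enumerate(components):
    print(f"TPU:{d}", comp)
```

Because the second tensor axis is unsharded, the mesh dimension "y" does not affect the slice, which is why each pair of devices holds the same component.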
Methods
as_proto
as_proto()
as_proto(self: tensorflow.python._pywrap_dtensor_device.Layout) -> tensorflow::dtensor::LayoutProto
Returns the LayoutProto protobuf message.
batch_sharded
@classmethod
batch_sharded(
mesh: tf.experimental.dtensor.Mesh,
batch_dim: str,
rank: int,
axis: int = 0
) -> 'Layout'
Returns a layout sharded on batch dimension.
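As a rough illustration of what "sharded on batch dimension" can mean, the sketch below builds the list of sharding specs such a layout would plausibly carry: the mesh dimension name at position `axis` and unsharded everywhere else. The function `batch_sharded_specs` and the "unsharded" placeholder are assumptions made for illustration; they are not part of the DTensor API.

```python
UNSHARDED = "unsharded"  # placeholder for dtensor.UNSHARDED (assumed value)

def batch_sharded_specs(batch_dim, rank, axis=0):
    """Sharding specs with `batch_dim` at `axis` and all other axes unsharded."""
    if not 0 <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    return [UNSHARDED] * axis + [batch_dim] + [UNSHARDED] * (rank - axis - 1)

specs = batch_sharded_specs("x", rank=3)
print(specs)
```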
delete
delete(
dims: List[int]
) -> 'Layout'
Returns the layout with the given dimensions deleted.
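A minimal sketch of the effect on the sharding specs, assuming that deleting a dimension simply drops the corresponding entry and leaves the mesh untouched. The helper `delete_dims` is hypothetical and operates on a plain list rather than a Layout.

```python
UNSHARDED = "unsharded"  # placeholder for dtensor.UNSHARDED (assumed value)

def delete_dims(specs, dims):
    """Return the sharding specs with the tensor axes listed in `dims` removed."""
    drop = set(dims)
    return [spec for i, spec in enumerate(specs) if i not in drop]

remaining = delete_dims(["x", UNSHARDED, "y"], [1])
print(remaining)
```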
from_device
@classmethod
from_device( device: str ) -> 'Layout'
Constructs a single-device layout for the given device.
from_proto
@classmethod
from_proto( layout_proto: layout_pb2.LayoutProto ) -> 'Layout'
Creates an instance from a LayoutProto.
from_single_device_mesh
@classmethod
from_single_device_mesh(
mesh: tf.experimental.dtensor.Mesh
) -> 'Layout'
Constructs a single device layout from a single device mesh.
from_string
@classmethod
from_string( layout_str: str ) -> 'Layout'
Creates an instance from a human-readable string.
global_shape_from_local_shape
global_shape_from_local_shape()
global_shape_from_local_shape(self: tensorflow.python._pywrap_dtensor_device.Layout, local_shape: List[int]) -> tuple
Returns the global shape computed from this local shape.
inner_sharded
@classmethod
inner_sharded(
mesh: tf.experimental.dtensor.Mesh,
inner_dim: str,
rank: int
) -> 'Layout'
Returns a layout sharded on inner dimension.
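For comparison with a batch-sharded layout, the sketch below assumes that "inner dimension" refers to the last tensor axis, so only the final spec names a mesh dimension. The helper `inner_sharded_specs` is hypothetical and only models the list of sharding specs.

```python
UNSHARDED = "unsharded"  # placeholder for dtensor.UNSHARDED (assumed value)

def inner_sharded_specs(inner_dim, rank):
    """Sharding specs with only the last tensor axis sharded on `inner_dim`."""
    if rank < 1:
        raise ValueError("rank must be at least 1")
    return [UNSHARDED] * (rank - 1) + [inner_dim]

inner_specs = inner_sharded_specs("y", rank=3)
print(inner_specs)
```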
is_batch_parallel
is_batch_parallel()
is_batch_parallel(self: tensorflow.python._pywrap_dtensor_device.Layout) -> bool
is_fully_replicated
is_fully_replicated()
is_fully_replicated(self: tensorflow.python._pywrap_dtensor_device.Layout) -> bool
Returns True if all tensor axes are replicated.
is_single_device
is_single_device()
is_single_device(self: tensorflow.python._pywrap_dtensor_device.Layout) -> bool
Returns True if the Layout represents a non-distributed device.
local_shape_from_global_shape
local_shape_from_global_shape()
local_shape_from_global_shape(self: tensorflow.python._pywrap_dtensor_device.Layout, global_shape: List[int]) -> tuple
Returns the local shape computed from this global shape.
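The relationship between local and global shapes can be modeled in a few lines, under the assumption that each sharded axis is divided evenly by the size of its mesh dimension and that unsharded axes keep their full length. The function names below are hypothetical stand-ins that work on plain lists; they do not call into DTensor.

```python
UNSHARDED = "unsharded"  # placeholder for dtensor.UNSHARDED (assumed value)

def shards_per_axis(specs, mesh_dims):
    """Number of shards along each tensor axis (1 for unsharded axes)."""
    sizes = dict(mesh_dims)
    return [1 if spec == UNSHARDED else sizes[spec] for spec in specs]

def local_from_global(specs, mesh_dims, global_shape):
    """Shape of the piece held by one device, assuming even division."""
    shards = shards_per_axis(specs, mesh_dims)
    if any(g % n for g, n in zip(global_shape, shards)):
        raise ValueError("global shape is not evenly divisible by the shard counts")
    return tuple(g // n for g, n in zip(global_shape, shards))

def global_from_local(specs, mesh_dims, local_shape):
    """Shape of the full tensor reconstructed from one device's piece."""
    shards = shards_per_axis(specs, mesh_dims)
    return tuple(l * n for l, n in zip(local_shape, shards))

mesh_dims = [("x", 3), ("y", 2)]
specs = ["x", UNSHARDED]
local = local_from_global(specs, mesh_dims, [3, 2])
restored = global_from_local(specs, mesh_dims, local)
print(local, restored)
```

With a 3x2 mesh and the first axis sharded on "x", a global shape of (3, 2) yields a local shape of (1, 2), and multiplying back recovers the global shape.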
num_shards
num_shards()
num_shards(self: tensorflow.python._pywrap_dtensor_device.Layout, idx: int) -> int
Returns the number of shards for tensor dimension idx.
offset_to_shard
offset_to_shard()
Mapping from offset in a flattened list to shard index.
offset_tuple_to_global_index
offset_tuple_to_global_index(
offset_tuple
)
Mapping from offset to index in global tensor.
replicated
@classmethod
replicated(
mesh: tf.experimental.dtensor.Mesh,
rank: int
) -> 'Layout'
Returns a replicated layout of rank rank.
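A replicated layout marks every tensor axis as unsharded. The sketch below models that on a plain list of specs, together with a check matching the description "all tensor axes are replicated". Both helper names and the "unsharded" placeholder are assumptions for illustration only.

```python
UNSHARDED = "unsharded"  # placeholder for dtensor.UNSHARDED (assumed value)

def replicated_specs(rank):
    """Sharding specs in which every tensor axis is unsharded."""
    return [UNSHARDED] * rank

def all_axes_replicated(specs):
    """True if no tensor axis is sharded over a mesh dimension."""
    return all(spec == UNSHARDED for spec in specs)

rep = replicated_specs(2)
print(rep, all_axes_replicated(rep), all_axes_replicated(["x", UNSHARDED]))
```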
to_parted
to_parted() -> 'Layout'
Returns a "parted" layout from a static layout.
A parted layout contains axes that are treated as independent by most SPMD expanders.
FIXME(b/285905569): The exact semantics are still being investigated.
to_string
to_string()
to_string(self: tensorflow.python._pywrap_dtensor_device.Layout) -> str
__eq__
__eq__()
__eq__(self: tensorflow.python._pywrap_dtensor_device.Layout, arg0: tensorflow.python._pywrap_dtensor_device.Layout) -> bool