tf.experimental.dtensor.Layout

Represents the layout information of a DTensor.

A layout describes how a distributed tensor is partitioned across a mesh (and thus across devices). For each axis of the tensor, the corresponding sharding spec names the mesh dimension that axis is sharded over. The special sharding spec UNSHARDED indicates that the axis is not split: it is replicated on all devices of the mesh.

For example, let's consider a 1-D mesh:

Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])

This mesh arranges 6 TPU devices into a 1-D array. Layout([UNSHARDED], mesh) is a layout for a rank-1 tensor that is replicated on all 6 devices.
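
The difference between replicating and sharding a rank-1 tensor on this mesh can be modelled in plain Python. This is a sketch, not DTensor code: components_1d is a hypothetical helper, UNSHARDED is a local stand-in for the library constant, and the tensor length is assumed to divide evenly by the number of devices.

```python
UNSHARDED = "unsharded"  # local stand-in for tf.experimental.dtensor.UNSHARDED


def components_1d(values, spec, mesh_size):
    """Per-device components of a rank-1 tensor on a 1-D mesh (hypothetical helper)."""
    if spec == UNSHARDED:
        # Replicated: every device holds a full copy of the tensor.
        return [list(values) for _ in range(mesh_size)]
    # Sharded over the single mesh dimension: split into equal contiguous blocks.
    block = len(values) // mesh_size  # assumes even division
    return [list(values[i * block:(i + 1) * block]) for i in range(mesh_size)]


print(components_1d(range(6), UNSHARDED, 6))  # Layout([UNSHARDED], mesh)
print(components_1d(range(6), "x", 6))        # Layout(["x"], mesh)
```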

For another example, let's consider a 2-D mesh:

Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
     [("x", 3), ("y", 2)])

This mesh arranges 6 TPU devices into a 3x2 2-D array. Layout(["x", UNSHARDED], mesh) is a layout for a rank-2 tensor whose first axis is sharded over mesh dimension "x" and whose second axis is replicated. If we place np.arange(6).reshape((3, 2)) using this layout, the individual component tensors look like this:

Device | Component
TPU:0  | [[0, 1]]
TPU:1  | [[0, 1]]
TPU:2  | [[2, 3]]
TPU:3  | [[2, 3]]
TPU:4  | [[4, 5]]
TPU:5  | [[4, 5]]
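
The table can be reproduced with a small pure-Python model. This is a sketch, not the DTensor implementation: device_coords and local_component are hypothetical helpers, UNSHARDED is a local stand-in for the library constant, devices are assumed to be numbered in row-major order over the mesh dimensions (last dimension varying fastest, which matches the table), and each sharded axis is assumed to divide evenly by its mesh dimension size.

```python
UNSHARDED = "unsharded"  # local stand-in for tf.experimental.dtensor.UNSHARDED


def device_coords(mesh_dims, device_index):
    """Map a flat device index to {mesh_dim_name: coordinate}, row-major."""
    coords = {}
    for name, size in reversed(mesh_dims):
        coords[name] = device_index % size
        device_index //= size
    return coords


def local_component(matrix, sharding_specs, mesh_dims, device_index):
    """Return the block of a rank-2 `matrix` (nested lists) held by one device."""
    sizes = dict(mesh_dims)
    coords = device_coords(mesh_dims, device_index)
    ranges = []
    for axis_len, spec in zip((len(matrix), len(matrix[0])), sharding_specs):
        if spec == UNSHARDED:
            ranges.append(range(axis_len))  # replicated: keep the whole axis
        else:
            block = axis_len // sizes[spec]  # assumes even division
            start = coords[spec] * block
            ranges.append(range(start, start + block))
    rows, cols = ranges
    return [[matrix[r][c] for c in cols] for r in rows]


mesh_dims = [("x", 3), ("y", 2)]
matrix = [[0, 1], [2, 3], [4, 5]]  # np.arange(6).reshape((3, 2)) as nested lists
for d in range(6):
    print(f"TPU:{d}", local_component(matrix, ["x", UNSHARDED], mesh_dims, d))
```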

Args

sharding_specs: A list of sharding specifications, one per tensor axis. Each specification (dim_sharding) is either a mesh dimension name or the special value UNSHARDED.
mesh: The mesh configuration for the tensor.

Methods

as_proto

Create a proto representation of a layout.

batch_sharded

Returns a layout that shards one tensor axis (by default the first, batch axis) over the given mesh dimension and replicates every other axis.
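
The sharding specs such a layout carries can be modelled as follows. This is a hypothetical pure-Python sketch that builds only the spec list, not a Layout object; it assumes the method takes the batch mesh dimension name, the tensor rank, and an optional axis that defaults to 0, and UNSHARDED is a local stand-in for the library constant.

```python
UNSHARDED = "unsharded"  # local stand-in for tf.experimental.dtensor.UNSHARDED


def batch_sharded_specs(batch_dim, rank, axis=0):
    """Specs that shard tensor axis `axis` over mesh dimension `batch_dim` only."""
    return [UNSHARDED] * axis + [batch_dim] + [UNSHARDED] * (rank - axis - 1)


print(batch_sharded_specs("x", 3))          # shard the leading (batch) axis
print(batch_sharded_specs("x", 3, axis=1))  # shard a different axis instead
```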

delete

Returns the layout with the given tensor dimensions deleted.
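
Deleting dimensions drops the corresponding sharding specs while the mesh stays the same. A hypothetical pure-Python model of that behaviour, which in this sketch accepts either a list of axis indices or a single index (UNSHARDED is a local stand-in for the library constant):

```python
UNSHARDED = "unsharded"  # local stand-in for tf.experimental.dtensor.UNSHARDED


def delete_axes(sharding_specs, dims):
    """Return the specs with the tensor axes listed in `dims` removed."""
    if not isinstance(dims, list):
        dims = [dims]  # also accept a single axis index
    return [spec for i, spec in enumerate(sharding_specs) if i not in dims]


# Drop tensor axis 1 of a rank-3 layout; the remaining specs keep their order.
print(delete_axes(["x", UNSHARDED, "y"], [1]))
```

One situation where this kind of operation could be useful is describing a result that has fewer axes than its input, such as a reduction that removes an axis.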

from_str

Creates an instance from a serialized Protobuf binary string.

from_string

Creates an instance from a human-readable string.

inner_sharded

Returns a layout that shards the innermost (last) tensor axis over the given mesh dimension and replicates every other axis.
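
This corresponds to batch sharding with the sharded axis moved to position rank - 1. A hypothetical pure-Python sketch of the resulting spec list (not a Layout object), assuming the method takes the inner mesh dimension name and the tensor rank; UNSHARDED is a local stand-in for the library constant:

```python
UNSHARDED = "unsharded"  # local stand-in for tf.experimental.dtensor.UNSHARDED


def inner_sharded_specs(inner_dim, rank):
    """Shard only the innermost (last) tensor axis over mesh dimension `inner_dim`."""
    return [UNSHARDED] * (rank - 1) + [inner_dim]


print(inner_sharded_specs("y", 3))
```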

is_fully_replicated

Returns True if all tensor axes are replicated.

mesh_proto

Returns the underlying mesh in Protobuf format.

num_shards

Returns the number of shards along tensor axis idx: 1 if the axis is UNSHARDED, or the size of the mesh dimension the axis is sharded over.
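
For the 2-D mesh and Layout(["x", UNSHARDED], mesh) from the example above, axis 0 is split into 3 shards and axis 1 into 1. A hypothetical pure-Python model of that count, with the mesh given as (name, size) pairs and UNSHARDED as a local stand-in for the library constant:

```python
UNSHARDED = "unsharded"  # local stand-in for tf.experimental.dtensor.UNSHARDED


def num_shards(sharding_specs, mesh_dims, idx):
    """Number of pieces tensor axis `idx` is split into."""
    spec = sharding_specs[idx]
    if spec == UNSHARDED:
        return 1  # replicated axis: a single, whole shard
    return dict(mesh_dims)[spec]  # size of the mesh dimension it is sharded over


mesh_dims = [("x", 3), ("y", 2)]
specs = ["x", UNSHARDED]
print([num_shards(specs, mesh_dims, i) for i in range(len(specs))])
```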

offset_to_shard

Returns, for each device offset in the flattened device list, the shard index (one coordinate per tensor axis) of the component that device holds.
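
A pure-Python sketch of this mapping, under stated assumptions: device offsets follow row-major order over the mesh dimensions (as in the component table above), mesh_coords and offset_to_shard are hypothetical helpers rather than the library code, and UNSHARDED is a local stand-in for the library constant.

```python
import math

UNSHARDED = "unsharded"  # local stand-in for tf.experimental.dtensor.UNSHARDED


def mesh_coords(mesh_dims, offset):
    """Row-major device coordinates: the last mesh dimension varies fastest."""
    coords = {}
    for name, size in reversed(mesh_dims):
        coords[name] = offset % size
        offset //= size
    return coords


def offset_to_shard(sharding_specs, mesh_dims):
    """For each device offset, the per-axis shard index of the component it holds."""
    num_devices = math.prod(size for _, size in mesh_dims)
    locations = []
    for offset in range(num_devices):
        coords = mesh_coords(mesh_dims, offset)
        # A replicated axis has a single shard, so its index is always 0.
        locations.append(tuple(0 if spec == UNSHARDED else coords[spec]
                               for spec in sharding_specs))
    return locations


print(offset_to_shard(["x", UNSHARDED], [("x", 3), ("y", 2)]))
```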

offset_tuple_to_global_index

Mapping from offset to index in global tensor.
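
The description is terse. As we read the TensorFlow source, offset_tuple holds one shard offset per tensor axis, and the method linearizes it in row-major order, using the num_shards of each later axis as the stride. The sketch below models that reading only; the explicit shards_per_axis argument is a hypothetical substitute for calling num_shards on a layout.

```python
def offset_tuple_to_global_index(offset_tuple, shards_per_axis):
    """Flatten per-axis shard offsets into one row-major index.

    shards_per_axis[i] plays the role of num_shards(i) for tensor axis i.
    """
    index = 0
    for i, offset in enumerate(offset_tuple):
        stride = 1
        for count in shards_per_axis[i + 1:]:
            stride *= count
        index += stride * offset
    return index


# A rank-2 tensor split into 3 x 2 shards: shard (2, 1) is the last of the 6.
print(offset_tuple_to_global_index((2, 1), [3, 2]))
```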

replicated

Returns a fully replicated layout (every axis UNSHARDED) for a tensor of the given rank.
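
A replicated layout is exactly what is_fully_replicated tests for. A hypothetical pure-Python model of both, working on spec lists rather than Layout objects, with UNSHARDED as a local stand-in for the library constant:

```python
UNSHARDED = "unsharded"  # local stand-in for tf.experimental.dtensor.UNSHARDED


def replicated_specs(rank):
    """Specs of a layout that keeps every tensor axis whole on every device."""
    return [UNSHARDED] * rank


def is_fully_replicated(sharding_specs):
    """True when no tensor axis is sharded over any mesh dimension."""
    return all(spec == UNSHARDED for spec in sharding_specs)


print(replicated_specs(2), is_fully_replicated(replicated_specs(2)))
print(is_fully_replicated(["x", UNSHARDED]))
```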

serialized_string

Returns a serialized Protobuf binary string representation.

to_string

Returns a human-readable string representation.

unravel

Converts a flattened list of per-device shards into an array indexed by shard position.
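
A pure-Python sketch of the idea, applied to the components from the 2-D example above. Assumptions: components are listed in row-major device order, unravel and mesh_coords are hypothetical helpers, and a dict keyed by shard coordinates stands in for a multi-dimensional array; UNSHARDED is a local stand-in for the library constant.

```python
UNSHARDED = "unsharded"  # local stand-in for tf.experimental.dtensor.UNSHARDED


def mesh_coords(mesh_dims, offset):
    """Row-major device coordinates: the last mesh dimension varies fastest."""
    coords = {}
    for name, size in reversed(mesh_dims):
        coords[name] = offset % size
        offset //= size
    return coords


def unravel(components, sharding_specs, mesh_dims):
    """Place per-device components at their shard coordinates.

    Returns (shape, grid): shape[i] is the number of shards along tensor
    axis i, and grid maps each shard-coordinate tuple to its component.
    """
    sizes = dict(mesh_dims)
    shape = [1 if spec == UNSHARDED else sizes[spec] for spec in sharding_specs]
    grid = {}
    for offset, component in enumerate(components):
        coords = mesh_coords(mesh_dims, offset)
        loc = tuple(0 if spec == UNSHARDED else coords[spec] for spec in sharding_specs)
        # Replicas map to the same location; they hold identical data.
        grid[loc] = component
    return shape, grid


components = [[[0, 1]], [[0, 1]], [[2, 3]], [[2, 3]], [[4, 5]], [[4, 5]]]
shape, grid = unravel(components, ["x", UNSHARDED], [("x", 3), ("y", 2)])
print(shape, grid)
```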

__eq__

Return self==value.