View source on GitHub
Changes the layout of `tensor` to match the layout of `layout_tensor`.
    tf.experimental.dtensor.relayout_like(
        tensor: tf.Tensor,
        layout_tensor: tf.Tensor,
        name: Optional[str] = None
    ) -> tf.Tensor
`relayout_like` is often used inside a `tf.function` to ensure a tensor is
placed on the same mesh and with the same layout as another tensor.
The backward gradient of a `relayout` is a `relayout_like` operation, which
ensures the backward tensor has the same layout as the forward input tensor:

    @ops.RegisterGradient("Relayout")
    def _relayout_gradient(op, grad):
      return relayout_like(grad, layout_tensor=op.inputs[0])
Here is another illustrative example:
    @tf.function
    def func(x):
      z = tf.ones(x.shape)
      z = dtensor.relayout_like(z, x)
      return x + z

    with dtensor.default_mesh(cpu_mesh):
      x = tf.ones((4, 4))

    with dtensor.default_mesh(gpu_mesh):
      y = func(x)

    # y would be on the cpu mesh, following the mesh of x.
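The mesh-following behavior can be sketched without a multi-device setup using a toy analogy. Everything below (`ToyTensor`, `toy_relayout_like`) is hypothetical illustration, not part of the TensorFlow API: it only models the idea that `relayout_like` keeps a tensor's values but adopts the mesh and layout of another tensor.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a DTensor: plain values plus placement metadata."""
    values: List[float]
    mesh: str            # which device mesh the tensor lives on
    layout: List[str]    # per-dimension sharding spec, e.g. ["x", "unsharded"]


def toy_relayout_like(tensor: ToyTensor, layout_tensor: ToyTensor) -> ToyTensor:
    # Mirrors relayout_like semantics in this toy model: keep the values,
    # adopt the mesh and layout of `layout_tensor`.
    return ToyTensor(values=list(tensor.values),
                     mesh=layout_tensor.mesh,
                     layout=list(layout_tensor.layout))


# x lives on a "cpu_mesh"; z starts out on a "gpu_mesh".
x = ToyTensor(values=[1.0, 2.0], mesh="cpu_mesh", layout=["x", "unsharded"])
z = ToyTensor(values=[1.0, 1.0], mesh="gpu_mesh", layout=["unsharded", "unsharded"])

z = toy_relayout_like(z, x)
print(z.mesh)    # cpu_mesh: z now follows x's placement, like y in the example above
print(z.layout)  # ['x', 'unsharded']
```

This is only an analogy: real DTensor relayouts may move data across devices, whereas the toy version just rewrites metadata.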
| Returns |
|---|
| A DTensor output from the RelayoutLike op. |