

An Op to exchange data across TPU replicas.

On each replica, the input is split into split_count blocks along split_dimension and sent to the other replicas given by group_assignment. After receiving split_count - 1 blocks from the other replicas, the blocks are concatenated along concat_dimension to form the output.

For example, suppose there are 2 TPU replicas:

replica 0 receives input: [[A, B]]
replica 1 receives input: [[C, D]]

group_assignment=[[0, 1]]
concat_dimension=0
split_dimension=1
split_count=2

replica 0's output: [[A], [C]]
replica 1's output: [[B], [D]]
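The exchange above can be sketched outside of TPU hardware. The following is a minimal NumPy simulation of the all-to-all semantics described in this doc; the helper name all_to_all and the list-of-arrays calling convention are illustrative assumptions, not part of the real op's API.

```python
import numpy as np

def all_to_all(inputs, group_assignment, concat_dimension, split_dimension, split_count):
    # inputs: one array per replica id; returns one output array per replica id.
    # (Illustrative helper, not the actual TPU op.)
    outputs = [None] * len(inputs)
    for group in group_assignment:
        # Each replica in the group splits its input into split_count blocks
        # along split_dimension.
        blocks = {r: np.split(inputs[r], split_count, axis=split_dimension)
                  for r in group}
        # The replica at position i in the group receives block i from every
        # group member, then concatenates the blocks along concat_dimension.
        for i, r in enumerate(group):
            outputs[r] = np.concatenate([blocks[s][i] for s in group],
                                        axis=concat_dimension)
    return outputs

# Two replicas, matching the example above.
r0 = np.array([["A", "B"]])   # replica 0 input, shape (1, 2)
r1 = np.array([["C", "D"]])   # replica 1 input, shape (1, 2)
out = all_to_all([r0, r1], group_assignment=[[0, 1]],
                 concat_dimension=0, split_dimension=1, split_count=2)
# out[0] is [["A"], ["C"]]; out[1] is [["B"], ["D"]]
```

Note how split_dimension=1 cuts each (1, 2) input into two (1, 1) blocks, and concat_dimension=0 stacks the received blocks into a (2, 1) output.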

input: A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, half, uint32, uint64, bool. The local input to the all-to-all.
group_assignment: A Tensor of type int32 with shape [num_groups, num_replicas_per_group]. group_assignment[i] represents the replica ids in the ith subgroup.
concat_dimension: An int. The dimension number to concatenate along.
split_dimension: An int. The dimension number to split along.
split_count: An int. The number of splits; this must equal the subgroup size, group_assignment.get_shape()[1].
name: A name for the operation (optional).

Returns: A Tensor. Has the same type as input.