An Op to exchange data across TPU replicas.
```python
tf.raw_ops.AllToAll(
    input,
    group_assignment,
    concat_dimension,
    split_dimension,
    split_count,
    name=None
)
```
On each replica, the input is split into `split_count` blocks along `split_dimension`, and the blocks are sent to the other replicas as specified by `group_assignment`. After receiving `split_count - 1` blocks from the other replicas, each replica concatenates them, together with the block it keeps for itself, along `concat_dimension` to form its output.
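
The split-then-concatenate description determines how the shape changes: `split_dimension` shrinks by a factor of `split_count`, and `concat_dimension` grows by the same factor. The helper below is a hypothetical sketch (not part of TensorFlow) that computes the per-replica output shape; it assumes the size of `split_dimension` must be divisible by `split_count`, since the input is cut into equal blocks.

```python
def all_to_all_output_shape(input_shape, concat_dimension, split_dimension, split_count):
    """Hypothetical helper: per-replica output shape implied by the split/concat rule."""
    shape = list(input_shape)
    # Assumption: the input must divide evenly into split_count blocks.
    if shape[split_dimension] % split_count != 0:
        raise ValueError(
            f"dimension {split_dimension} (size {shape[split_dimension]}) "
            f"is not divisible by split_count={split_count}"
        )
    # Splitting shrinks split_dimension by a factor of split_count ...
    shape[split_dimension] //= split_count
    # ... and concatenating split_count blocks grows concat_dimension by the same factor.
    shape[concat_dimension] *= split_count
    return tuple(shape)


print(all_to_all_output_shape((1, 2), concat_dimension=0, split_dimension=1, split_count=2))
```

For the two-replica example that follows, an input of shape `(1, 2)` gives an output of shape `(2, 1)`. When `concat_dimension` equals `split_dimension`, the shape is unchanged; only the data is exchanged.
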
For example, suppose there are 2 TPU replicas:
- replica 0 receives input: `[[A, B]]`
- replica 1 receives input: `[[C, D]]`
- `group_assignment=[[0, 1]]`
- `concat_dimension=0`
- `split_dimension=1`
- `split_count=2`
- replica 0's output: `[[A], [C]]`
- replica 1's output: `[[B], [D]]`
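
To check the example by hand, the sketch below models the exchange on the host with NumPy. `simulate_all_to_all` is a hypothetical helper, not a TensorFlow API, and it does not run on TPUs. It assumes that block `i` from every member of a group goes to the member at position `i` in that group, and that received blocks are concatenated in group order; this ordering reproduces the example above, but it is an assumption rather than something this page states. The symbols A through D are replaced with the numbers 1 through 4.

```python
import numpy as np


def simulate_all_to_all(inputs, group_assignment, concat_dimension, split_dimension, split_count):
    """Hypothetical host-side model of AllToAll; inputs[r] is replica r's input."""
    outputs = [None] * len(inputs)
    for group in group_assignment:
        if len(group) != split_count:
            raise ValueError("each group must contain split_count replicas")
        # Every member of the group splits its input into split_count equal blocks.
        blocks = {r: np.split(inputs[r], split_count, axis=split_dimension) for r in group}
        # Assumed ordering: the member at position i collects block i from every
        # member (including itself), concatenated in group order.
        for i, receiver in enumerate(group):
            received = [blocks[sender][i] for sender in group]
            outputs[receiver] = np.concatenate(received, axis=concat_dimension)
    return outputs


A, B, C, D = 1, 2, 3, 4  # stand-in values for the symbolic entries
inputs = [np.array([[A, B]]), np.array([[C, D]])]
out = simulate_all_to_all(inputs, [[0, 1]], concat_dimension=0, split_dimension=1, split_count=2)
print(out[0].tolist())  # [[1], [3]], i.e. [[A], [C]]
print(out[1].tolist())  # [[2], [4]], i.e. [[B], [D]]
```
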

| Returns |
|---|
| A `Tensor`. Has the same type as `input`. |