AllToAll

public final class AllToAll<T>

An Op to exchange data across TPU replicas.

On each replica, the input is split into `split_count` blocks along `split_dimension`, and the blocks are sent to the other replicas according to `group_assignment`. After receiving `split_count` - 1 blocks from the other replicas, each replica concatenates them, together with the block it kept for itself, along `concat_dimension` to form its output.

For example, suppose there are 2 TPU replicas:

replica 0 receives input: `[[A, B]]`
replica 1 receives input: `[[C, D]]`

group_assignment=`[[0, 1]]`
concat_dimension=0
split_dimension=1
split_count=2

replica 0's output: `[[A], [C]]`
replica 1's output: `[[B], [D]]`
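The exchange in this example can be reproduced on the host in plain Java to check the semantics. This is a minimal sketch, not the TPU implementation: it assumes 2-D inputs, assumes the split dimension divides evenly into `split_count` blocks, and assumes each replica concatenates the received blocks in the order of the senders' positions within its subgroup (which is consistent with the example). The class and method names (`AllToAllSimulation`, `split`, `concat`, `allToAll`) are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical host-side simulation of AllToAll semantics (not the TPU kernel).
final class AllToAllSimulation {

    // Splits a 2-D matrix into `count` equal blocks along `dim` (0 = rows, 1 = columns).
    // Assumption: the split dimension must be divisible by `count`.
    static String[][][] split(String[][] m, int dim, int count) {
        int rows = m.length;
        int size = (dim == 0) ? rows : m[0].length;
        if (size % count != 0) {
            throw new IllegalArgumentException("size " + size + " not divisible by " + count);
        }
        int step = size / count;
        String[][][] blocks = new String[count][][];
        for (int b = 0; b < count; b++) {
            if (dim == 0) {
                blocks[b] = Arrays.copyOfRange(m, b * step, (b + 1) * step);
            } else {
                String[][] block = new String[rows][];
                for (int r = 0; r < rows; r++) {
                    block[r] = Arrays.copyOfRange(m[r], b * step, (b + 1) * step);
                }
                blocks[b] = block;
            }
        }
        return blocks;
    }

    // Concatenates 2-D blocks along `dim`, in array order.
    static String[][] concat(String[][][] blocks, int dim) {
        if (dim == 0) {
            List<String[]> rows = new ArrayList<>();
            for (String[][] block : blocks) {
                rows.addAll(Arrays.asList(block));
            }
            return rows.toArray(new String[0][]);
        }
        String[][] out = new String[blocks[0].length][];
        for (int r = 0; r < out.length; r++) {
            List<String> row = new ArrayList<>();
            for (String[][] block : blocks) {
                row.addAll(Arrays.asList(block[r]));
            }
            out[r] = row.toArray(new String[0]);
        }
        return out;
    }

    // inputs[r] is replica r's local input; returns each replica's output.
    static String[][][] allToAll(String[][][] inputs, int[][] groupAssignment,
                                 int concatDim, int splitDim, int splitCount) {
        String[][][] outputs = new String[inputs.length][][];
        for (int[] group : groupAssignment) {
            if (group.length != splitCount) {
                throw new IllegalArgumentException("splitCount must equal the sub-group size");
            }
            // sent[j][k] is block k of the input held by the j-th member of the group.
            String[][][][] sent = new String[splitCount][][][];
            for (int j = 0; j < splitCount; j++) {
                sent[j] = split(inputs[group[j]], splitDim, splitCount);
            }
            // The k-th member receives block k from every member (itself included)
            // and concatenates them in member order.
            for (int k = 0; k < splitCount; k++) {
                String[][][] received = new String[splitCount][][];
                for (int j = 0; j < splitCount; j++) {
                    received[j] = sent[j][k];
                }
                outputs[group[k]] = concat(received, concatDim);
            }
        }
        return outputs;
    }

    public static void main(String[] args) {
        // Replica 0 holds [[A, B]], replica 1 holds [[C, D]].
        String[][][] inputs = { {{"A", "B"}}, {{"C", "D"}} };
        String[][][] out = allToAll(inputs, new int[][] {{0, 1}}, 0, 1, 2);
        for (int r = 0; r < out.length; r++) {
            System.out.println("replica " + r + "'s output: " + Arrays.deepToString(out[r]));
        }
    }
}
```

Running `main` prints `replica 0's output: [[A], [C]]` and `replica 1's output: [[B], [D]]`, matching the example. With four replicas and a group assignment of `{{0, 1}, {2, 3}}`, blocks are exchanged only within each pair.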

Public Methods

Output<T> asOutput()
    Returns the symbolic handle of a tensor.

static <T> AllToAll<T> create(Scope scope, Operand<T> input, Operand<Integer> groupAssignment, Long concatDimension, Long splitDimension, Long splitCount)
    Factory method to create a class wrapping a new AllToAll operation.

Output<T> output()
    The exchanged result.

Inherited Methods

org.tensorflow.op.PrimitiveOp
final boolean equals(Object obj)
final int hashCode()
Operation op()
    Returns the underlying Operation.
final String toString()

java.lang.Object
boolean equals(Object arg0)
final Class<?> getClass()
int hashCode()
final void notify()
final void notifyAll()
String toString()
final void wait(long arg0, int arg1)
final void wait(long arg0)
final void wait()

org.tensorflow.Operand
abstract Output<T> asOutput()
    Returns the symbolic handle of a tensor.

Public Methods

public Output<T> asOutput ()

Returns the symbolic handle of a tensor.

Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is used to obtain a symbolic handle that represents the computation of the input.

public static <T> AllToAll<T> create (Scope scope, Operand<T> input, Operand<Integer> groupAssignment, Long concatDimension, Long splitDimension, Long splitCount)

Factory method to create a class wrapping a new AllToAll operation.

Parameters
scope current scope
input The local input tensor to be split and exchanged.
groupAssignment An int32 tensor with shape [num_groups, num_replicas_per_group]. `group_assignment[i]` lists the replica ids in the i-th subgroup.
concatDimension The dimension number to concatenate.
splitDimension The dimension number to split.
splitCount The number of splits; this number must equal the sub-group size (`group_assignment.get_shape()[1]`).
Returns
  • a new instance of AllToAll
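Before calling `create`, the documented constraint on `splitCount` can be checked on the host, and the per-replica output shape follows from the split and concatenation: the split dimension shrinks by a factor of `splitCount` and the concat dimension grows by the same factor. The helper below is a hypothetical sketch, not part of the TensorFlow Java API. It assumes a fully known static input shape and assumes that the size of `splitDimension` must be divisible by `splitCount`; the parameter list above states only the sub-group-size requirement.

```java
import java.util.Arrays;

// Hypothetical helper, not part of TensorFlow Java: validates AllToAll arguments
// against a static input shape and returns the per-replica output shape.
final class AllToAllShapes {

    static long[] outputShape(long[] inputShape, int[][] groupAssignment,
                              long concatDimension, long splitDimension, long splitCount) {
        int rank = inputShape.length;
        if (concatDimension < 0 || concatDimension >= rank
                || splitDimension < 0 || splitDimension >= rank) {
            throw new IllegalArgumentException("dimension out of range for rank " + rank);
        }
        // group_assignment has shape [num_groups, num_replicas_per_group].
        int groupSize = groupAssignment[0].length;
        for (int[] group : groupAssignment) {
            if (group.length != groupSize) {
                throw new IllegalArgumentException("group_assignment must be rectangular");
            }
        }
        // Documented constraint: splitCount == group_assignment.get_shape()[1].
        if (splitCount != groupSize) {
            throw new IllegalArgumentException(
                "splitCount " + splitCount + " != sub-group size " + groupSize);
        }
        // Assumption: the split dimension divides evenly into splitCount blocks.
        if (inputShape[(int) splitDimension] % splitCount != 0) {
            throw new IllegalArgumentException("split dimension not divisible by splitCount");
        }
        long[] out = inputShape.clone();
        out[(int) splitDimension] /= splitCount;   // each block is 1/splitCount as large
        out[(int) concatDimension] *= splitCount;  // splitCount blocks are concatenated
        return out;
    }

    public static void main(String[] args) {
        // Shape of the documented example: input [1, 2], split on dim 1, concat on dim 0.
        long[] shape = outputShape(new long[] {1, 2}, new int[][] {{0, 1}}, 0, 1, 2);
        System.out.println(Arrays.toString(shape));
    }
}
```

For the 2-replica example, an input of shape [1, 2] yields an output of shape [2, 1], which matches `[[A], [C]]`. When `splitDimension` equals `concatDimension`, the output shape equals the input shape.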

public Output<T> output ()

The exchanged result.