tfm.utils.cross_replica_concat

Concatenates the given value across (GPU/TPU) cores, along axis.

In general, each core ("replica") will pass a replica-specific value as value (corresponding to some element of a data-parallel computation taking place across replicas).

The resulting concatenated Tensor will have the same shape as value for all dimensions except axis, where it will be larger by a factor of the number of replicas. It will also have the same dtype as value.

The position of a given replica's value within the resulting concatenation is determined by that replica's replica ID. For example:

With value for replica 0 given as

0 0 0
0 0 0

and value for replica 1 given as

1 1 1
1 1 1

the resulting concatenation along axis 0 will be

0 0 0
0 0 0
1 1 1
1 1 1

and this result will be identical across all replicas.
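
The example above can be reproduced off-device with a small NumPy simulation. This is only a sketch of the documented ordering and shape behavior, not the library's implementation: `simulate_cross_replica_concat` is a hypothetical helper, and a plain Python list indexed by replica ID stands in for the per-replica values.

```python
import numpy as np


def simulate_cross_replica_concat(per_replica_values, axis):
    """Hypothetical stand-in: per_replica_values[i] is the value from replica i."""
    # Concatenating in list order mirrors ordering by replica ID.
    return np.concatenate(per_replica_values, axis=axis)


# Replica 0 contributes a 2x3 block of zeros, replica 1 a 2x3 block of ones.
replica_values = [
    np.full((2, 3), replica_id, dtype=np.int32) for replica_id in range(2)
]

result = simulate_cross_replica_concat(replica_values, axis=0)
print(result)
print(result.shape, result.dtype)
```

With two replicas and axis 0, only the first dimension grows (from 2 to 4); the other dimension and the dtype are unchanged.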

Note that this API only works in TF2 with tf.distribute.

Args:
  value: The Tensor to concatenate across replicas. Each replica will have a different value for this Tensor, and these replica-specific values will be concatenated.
  axis: The axis along which to perform the concatenation, as a Python integer (not a Tensor). E.g., axis=0 to concatenate along the batch dimension.
  name: A name for the operation (used to create a name scope).

Returns: The result of concatenating value along axis across replicas.

Raises: RuntimeError when the batch (0-th) dimension of value is None, i.e. not statically known.
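
One common way a batch dimension ends up as None is batching a finite `tf.data` dataset without dropping the final partial batch. This is an assumption about typical usage rather than something this reference states. The sketch below only inspects static shapes and does not call `cross_replica_concat`.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)

# The last batch may be smaller, so the batch dimension is not statically known.
dynamic_batches = dataset.batch(4)
# Dropping the remainder fixes the batch dimension at 4.
static_batches = dataset.batch(4, drop_remainder=True)

print(dynamic_batches.element_spec.shape)  # (None,)
print(static_batches.element_spec.shape)   # (4,)
```

Tensors drawn from the second dataset have a known 0-th dimension, which is the condition the RuntimeError guards.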