View source on GitHub
Concatenates the given `value` across (GPU/TPU) cores, along `axis`.
tfm.utils.cross_replica_concat(
value, axis, name='cross_replica_concat'
)
In general, each core ("replica") passes its own replica-specific tensor as
`value`, typically that replica's share of a data-parallel computation
running across all replicas.
The resulting concatenated Tensor will have the same shape as `value` for
all dimensions except `axis`, where it will be larger by a factor of the
number of replicas. It will also have the same dtype as `value`.
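The shape rule above can be sketched in plain Python. This is only an illustration of the arithmetic, not part of the library: `concatenated_shape` and its `num_replicas` parameter are hypothetical names introduced here, and the sketch assumes every dimension of the shape is statically known.

```python
def concatenated_shape(value_shape, axis, num_replicas):
    """Return the expected result shape (hypothetical helper, not a tfm API).

    Every dimension matches `value_shape` except `axis`, which is
    multiplied by the number of replicas.
    """
    shape = list(value_shape)
    shape[axis] *= num_replicas
    return tuple(shape)


# A (2, 3) value on each of 2 replicas, concatenated along axis 0.
shape_axis_0 = concatenated_shape((2, 3), axis=0, num_replicas=2)

# The same per-replica value on 4 replicas, concatenated along axis 1.
shape_axis_1 = concatenated_shape((2, 3), axis=1, num_replicas=4)
```

Here `shape_axis_0` is `(4, 3)` and `shape_axis_1` is `(2, 12)`.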
The position of a given replica's value within the resulting concatenation
is determined by that replica's replica ID. For
example:
With `value` for replica 0 given as
0 0 0
0 0 0
and `value` for replica 1 given as
1 1 1
1 1 1
the resulting concatenation along `axis` 0 will be
0 0 0
0 0 0
1 1 1
1 1 1
and this result will be identical across all replicas.
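The ordering in this example can be reproduced without any accelerators. The sketch below only simulates the semantics with nested Python lists; it does not call `tfm.utils.cross_replica_concat`. The helper `concat_nested` and the list `replica_values` are hypothetical names, and the per-replica values are assumed to be listed in replica-ID order.

```python
def concat_nested(values, axis):
    """Concatenate equally shaped nested lists along `axis` (0-based).

    Hypothetical stand-in for the cross-replica concatenation: `values`
    holds one nested list per replica, ordered by replica ID.
    """
    if axis == 0:
        # Append each replica's outermost entries in replica-ID order.
        return [item for value in values for item in value]
    # Otherwise pair up matching entries and concatenate one level deeper.
    return [concat_nested(list(parts), axis - 1) for parts in zip(*values)]


replica_values = [
    [[0, 0, 0], [0, 0, 0]],  # value on replica 0
    [[1, 1, 1], [1, 1, 1]],  # value on replica 1
]

along_axis_0 = concat_nested(replica_values, axis=0)
along_axis_1 = concat_nested(replica_values, axis=1)
```

With `axis=0` the result has four rows, replica 0's rows first, matching the example above. With `axis=1` each of the two rows becomes `[0, 0, 0, 1, 1, 1]`.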
Note that this API only works in TF2 with `tf.distribute`.
| Returns | |
|---|---|
| The result of concatenating `value` along `axis` across replicas. | |
| Raises | |
|---|---|
| `RuntimeError` | when the batch (0-th) dimension is `None`. |