
CrossReplicaSum

public final class CrossReplicaSum

An Op to sum inputs across replicated TPU instances.

Each instance supplies its own input.

For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`. Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0 and `B, D, F, H` as group 1. Each instance receives the sum over its own group, so the outputs are: `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.
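The grouping rule above can be simulated in plain Java, without TensorFlow or a TPU. This is a minimal sketch, assuming each replica holds a single scalar and appears in exactly one group (as in the example); the class and method names (`CrossReplicaSumSimulation`, `crossReplicaSum`) are hypothetical and are not part of the TensorFlow API.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java simulation of CrossReplicaSum semantics (not the TensorFlow op).
 * Assumption: each replica holds one scalar and belongs to exactly one group.
 */
public class CrossReplicaSumSimulation {

    /** Returns, for each replica id, the sum of the inputs in the group containing that replica. */
    static double[] crossReplicaSum(double[] inputs, int[][] groupAssignment) {
        double[] outputs = new double[inputs.length];
        boolean[] seen = new boolean[inputs.length];
        for (int[] group : groupAssignment) {
            // Sum the inputs of every replica in this group, rejecting bad or repeated ids.
            double groupSum = 0.0;
            for (int replica : group) {
                if (replica < 0 || replica >= inputs.length || seen[replica]) {
                    throw new IllegalArgumentException("invalid or repeated replica id: " + replica);
                }
                seen[replica] = true;
                groupSum += inputs[replica];
            }
            // Every member of the group receives the same group-wide sum.
            for (int replica : group) {
                outputs[replica] = groupSum;
            }
        }
        // Under the stated assumption, every replica must appear in some group.
        for (int replica = 0; replica < seen.length; replica++) {
            if (!seen[replica]) {
                throw new IllegalArgumentException("replica not assigned to any group: " + replica);
            }
        }
        return outputs;
    }

    public static void main(String[] args) {
        // Stand-ins for A..H on replicas 0..7.
        double[] inputs = {1, 2, 3, 4, 5, 6, 7, 8};
        int[][] groupAssignment = {{0, 2, 4, 6}, {1, 3, 5, 7}};
        System.out.println(Arrays.toString(crossReplicaSum(inputs, groupAssignment)));
    }
}
```

With the values 1 through 8 standing in for `A` through `H`, group 0 sums to 1+3+5+7 = 16 and group 1 to 2+4+6+8 = 20, so `main` prints `[16.0, 20.0, 16.0, 20.0, 16.0, 20.0, 16.0, 20.0]`, matching the alternating pattern in the example.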

Public Methods

Output<T> asOutput()
Returns the symbolic handle of a tensor.

static <T extends Number> CrossReplicaSum<T> create(Scope scope, Operand<T> input, Operand<Integer> groupAssignment)
Factory method to create a class wrapping a new CrossReplicaSum operation.

Output<T> output()
The sum of all the distributed inputs.


Public Methods

public Output<T> asOutput()

Returns the symbolic handle of a tensor.

Inputs to TensorFlow operations are outputs of other TensorFlow operations. This method obtains a symbolic handle that represents the computation of the input.

public static <T extends Number> CrossReplicaSum<T> create(Scope scope, Operand<T> input, Operand<Integer> groupAssignment)

Factory method to create a class wrapping a new CrossReplicaSum operation.

Parameters
  • scope: the current scope
  • input: the local input to the sum
  • groupAssignment: an int32 tensor with shape [num_groups, num_replicas_per_group]; `group_assignment[i]` lists the replica IDs in the i-th subgroup
Returns
  • a new instance of CrossReplicaSum
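Because `groupAssignment` has shape [num_groups, num_replicas_per_group], every group must contain the same number of replicas. As a sketch, the strided assignment from the example above (`[[0,2,4,6],[1,3,5,7]]`) can be generated in plain Java; `GroupAssignmentBuilder` and `stridedGroups` are hypothetical helpers, and converting the resulting `int[][]` into the `Operand<Integer>` that `create` expects is not shown.

```java
import java.util.Arrays;

/** Hypothetical helper that builds a rectangular [numGroups, numReplicasPerGroup] assignment. */
public class GroupAssignmentBuilder {

    /**
     * Assigns replica r to group r % numGroups, which reproduces
     * [[0,2,4,6],[1,3,5,7]] for numReplicas = 8 and numGroups = 2.
     */
    static int[][] stridedGroups(int numReplicas, int numGroups) {
        // A rectangular shape requires an equal number of replicas per group.
        if (numGroups <= 0 || numReplicas % numGroups != 0) {
            throw new IllegalArgumentException("numReplicas must be a positive multiple of numGroups");
        }
        int perGroup = numReplicas / numGroups;
        int[][] groups = new int[numGroups][perGroup];
        for (int g = 0; g < numGroups; g++) {
            for (int j = 0; j < perGroup; j++) {
                // The j-th member of group g is replica g + j * numGroups.
                groups[g][j] = g + j * numGroups;
            }
        }
        return groups;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.deepToString(stridedGroups(8, 2)));
    }
}
```

Running `main` prints `[[0, 2, 4, 6], [1, 3, 5, 7]]`. A contiguous layout such as `[[0,1,2,3],[4,5,6,7]]` is an equally valid rectangular assignment; the strided form is shown only because it matches the example.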

public Output<T> output()

The sum of all the distributed inputs.