tf.distribute.HierarchicalCopyAllReduce

Reduction using hierarchical copy all-reduce.

Compat aliases for migration

See Migration guide for more details.

tf.compat.v1.distribute.HierarchicalCopyAllReduce

It reduces to one GPU along edges in some hierarchy and broadcasts back to each GPU along the same path. Before performing all-reduce, tensors will be repacked or aggregated for more efficient cross-device transportation.
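The reduce-then-broadcast pattern described above can be sketched in plain Python. This is only an illustration of the idea, not TensorFlow's implementation: the function name `hierarchical_all_reduce`, the `group_size` parameter, and the two-level hierarchy are assumptions made for the example, and plain numbers stand in for per-GPU tensors.

```python
def hierarchical_all_reduce(values, group_size=2):
    """Sum one value per simulated device up a two-level hierarchy,
    then copy the total back to every device.

    `values` and `group_size` are hypothetical names for this sketch.
    """
    if not values:
        raise ValueError("values must not be empty")

    # Level 1: each group of devices reduces onto its group leader.
    groups = [values[i:i + group_size] for i in range(0, len(values), group_size)]
    leader_sums = [sum(group) for group in groups]

    # Level 2: the group leaders reduce onto a single root device.
    total = sum(leader_sums)

    # Broadcast: the root copies the result back down the same path,
    # so every device ends up holding the same reduced value.
    return [total for _ in values]


print(hierarchical_all_reduce([1.0, 2.0, 3.0, 4.0]))  # [10.0, 10.0, 10.0, 10.0]
```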

This reduction was created for the Nvidia DGX-1 and assumes that the GPUs are connected as they are on a DGX-1 machine. If your GPUs are interconnected differently, it is likely to be slower than tf.distribute.ReductionToOneDevice.
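A minimal usage sketch, assuming the object is passed to tf.distribute.MirroredStrategy through its cross_device_ops argument. The value num_packs=2 and the variable name `weight` are arbitrary choices for illustration; on a machine without multiple GPUs the strategy simply runs with a single replica.

```python
import tensorflow as tf

# num_packs=2 is an arbitrary illustrative value, not a recommendation.
cross_device_ops = tf.distribute.HierarchicalCopyAllReduce(num_packs=2)

# Use the hierarchical copy all-reduce for cross-device communication.
strategy = tf.distribute.MirroredStrategy(cross_device_ops=cross_device_ops)

with strategy.scope():
    # Variables created in the scope are mirrored across the replicas.
    weight = tf.Variable(1.0)

print("Replicas in sync:", strategy.num_replicas_in_sync)
```

If the GPU topology does not resemble a DGX-1, passing tf.distribute.ReductionToOneDevice() as cross_device_ops instead is the alternative the text above points to.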

Args
num_packs The number of splits into which values are packed. Must be greater than 0.

Raises
ValueError if num_packs is zero or negative.
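To make the role of num_packs concrete, the sketch below splits a list of values into that many near-equal packs and rejects non-positive counts, following the contract stated above. The helper `pack_values` and its splitting rule are assumptions for illustration; TensorFlow's actual packing logic may differ.

```python
def pack_values(values, num_packs):
    """Split `values` into at most `num_packs` contiguous, near-equal packs.

    Hypothetical helper for illustration; not part of TensorFlow.
    """
    if num_packs <= 0:
        raise ValueError("num_packs must be greater than 0, got %d" % num_packs)

    # Never create more packs than there are values (but always at least one).
    num_packs = min(num_packs, len(values)) or 1
    base, extra = divmod(len(values), num_packs)

    packs, start = [], 0
    for i in range(num_packs):
        # The first `extra` packs take one additional value each.
        size = base + (1 if i < extra else 0)
        packs.append(values[start:start + size])
        start += size
    return packs


print(pack_values(["g0", "g1", "g2", "g3", "g4"], 2))
# [['g0', 'g1', 'g2'], ['g3', 'g4']]
```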

Methods

batch_reduce

Reduce PerReplica objects in a batch.

For each pair in value_destination_pairs, reduces the first element (the value) to the destinations given by the second element.

Args
reduce_op Indicates how per_replica_value will be reduced. Accepted values are tf.distribute.ReduceOp.SUM and tf.distribute.ReduceOp.MEAN.
value_destination_pairs a list or tuple of (value, destinations) tuples, where each value is a PerReplica object (or a tensor with its device set, if there is only one device).

Returns
a list of Mirrored objects.

Raises
ValueError if value_destination_pairs is not a list or tuple of (PerReplica object, destinations) tuples.
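The shape of value_destination_pairs can be illustrated with a plain-Python stand-in. Here lists of numbers play the role of PerReplica objects, device strings play the role of destinations, and dictionaries play the role of Mirrored objects; `batch_reduce_sum` is a hypothetical helper that only covers the SUM case and does not call TensorFlow.

```python
def batch_reduce_sum(value_destination_pairs):
    """Sum each per-replica value list and mirror it onto its destinations.

    Hypothetical stand-in for illustration; not the TensorFlow method.
    """
    if not isinstance(value_destination_pairs, (list, tuple)):
        raise ValueError("value_destination_pairs must be a list or a tuple")

    results = []
    for pair in value_destination_pairs:
        if not isinstance(pair, tuple) or len(pair) != 2:
            raise ValueError("each entry must be a (value, destinations) tuple")
        per_replica_values, destinations = pair
        total = sum(per_replica_values)
        # One "mirrored" result per pair: the same value on every destination.
        results.append({device: total for device in destinations})
    return results


pairs = [
    ([1.0, 2.0], ["/gpu:0", "/gpu:1"]),
    ([10.0, 20.0], ["/gpu:0"]),
]
print(batch_reduce_sum(pairs))
# [{'/gpu:0': 3.0, '/gpu:1': 3.0}, {'/gpu:0': 30.0}]
```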

broadcast

Broadcast the tensor to destinations.

Args
tensor the tensor to broadcast.
destinations the broadcast destinations.

Returns
a Mirrored object.

reduce

Reduce per_replica_value to destinations.

It runs the reduction operation defined by reduce_op and puts the result on destinations.

Args
reduce_op Indicates how per_replica_value will be reduced. Accepted values are tf.distribute.ReduceOp.SUM and tf.distribute.ReduceOp.MEAN.
per_replica_value a PerReplica object or a tensor with device set.
destinations the reduction destinations.

Returns
a Mirrored object.

Raises
ValueError if per_replica_value can't be converted to a PerReplica object.
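The difference between the two accepted reduce_op values can be sketched without TensorFlow. In this illustration the `ReduceOp` enum is a hypothetical stand-in for tf.distribute.ReduceOp, a list of numbers stands in for a PerReplica object, and the returned dictionary stands in for a Mirrored object.

```python
from enum import Enum


class ReduceOp(Enum):
    """Hypothetical stand-in for tf.distribute.ReduceOp."""
    SUM = "sum"
    MEAN = "mean"


def reduce_to_destinations(reduce_op, per_replica_value, destinations):
    """Combine per-replica numbers and place the result on each destination."""
    if not per_replica_value:
        raise ValueError("per_replica_value must hold at least one value")

    total = sum(per_replica_value)
    if reduce_op is ReduceOp.SUM:
        result = total
    elif reduce_op is ReduceOp.MEAN:
        # MEAN divides the sum by the number of replicas.
        result = total / len(per_replica_value)
    else:
        raise ValueError("unsupported reduce_op: %r" % (reduce_op,))

    return {device: result for device in destinations}


mirrored = reduce_to_destinations(ReduceOp.MEAN, [2.0, 4.0], ["/gpu:0", "/gpu:1"])
print(mirrored)  # {'/gpu:0': 3.0, '/gpu:1': 3.0}
```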