

Partitioner that keeps shards below max_shard_bytes.

Inherits From: Partitioner

This partitioner ensures each shard has at most max_shard_bytes, and tries to allocate as few shards as possible, i.e., keeping shard size as large as possible.

If the partitioner hits the max_shards limit, then each shard may end up larger than max_shard_bytes. By default max_shards equals None and no limit on the number of shards is enforced.
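The sizing rule described above can be sketched in plain Python. This is an illustrative approximation, not TensorFlow's implementation: the helper name `sketch_num_shards` and its `bytes_per_element` parameter are hypothetical, and the sketch assumes that shards are split along a single axis and that the shard count is capped by the length of that axis.

```python
import math


def sketch_num_shards(shape, bytes_per_element, max_shard_bytes,
                      max_shards=None, axis=0):
    """Hypothetical helper: estimate the number of partitions per axis."""
    total_bytes = math.prod(shape) * bytes_per_element
    # Fewest shards whose average size is at or below max_shard_bytes.
    num_shards = math.ceil(total_bytes / max_shard_bytes)
    # Assumption: an axis cannot be split into more pieces than it has entries.
    num_shards = min(num_shards, shape[axis])
    # max_shards takes precedence, so shards may end up above max_shard_bytes.
    if max_shards is not None:
        num_shards = min(num_shards, max_shards)
    partitions = [1] * len(shape)
    partitions[axis] = max(1, num_shards)
    return partitions


# A [6, 1] float32 tensor holds 6 * 4 = 24 bytes.
print(sketch_num_shards([6, 1], 4, max_shard_bytes=4))                # [6, 1]
print(sketch_num_shards([6, 1], 4, max_shard_bytes=4, max_shards=2))  # [2, 1]
print(sketch_num_shards([6, 1], 4, max_shard_bytes=1024))             # [1, 1]
```

For string values, the same arithmetic would presumably use the bytes_per_string estimate in place of a fixed element size.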


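import tensorflow as tf

MaxSizePartitioner = tf.distribute.experimental.partitioners.MaxSizePartitioner

# 6 float32 values = 24 bytes; a 4-byte cap needs 6 shards along axis 0.
partitioner = MaxSizePartitioner(max_shard_bytes=4)
partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
print(partitions)  # [6, 1]

# max_shards=2 takes precedence, so each shard exceeds 4 bytes.
partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
print(partitions)  # [2, 1]

# The whole 24-byte tensor fits in a single 1024-byte shard.
partitioner = MaxSizePartitioner(max_shard_bytes=1024)
partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
print(partitions)  # [1, 1]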

# use in ParameterServerStrategy
# strategy = tf.distribute.experimental.ParameterServerStrategy(
#   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)

max_shard_bytes: The maximum size, in bytes, that any given shard is allowed to be.
max_shards: The maximum number of shards to create, as an int. Takes precedence over max_shard_bytes.
bytes_per_string: If the partition value is of type string, an estimate of how large each string is, in bytes.




When called, the partitioner partitions the given shape and returns the partition results.

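Example of a partitioner that allocates a fixed number of shards:

import tensorflow as tf
partitioner = tf.distribute.experimental.partitioners.FixedShardsPartitioner(num_shards=2)
# shape and dtype are separate arguments; only axis 0 is split.
partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
print(partitions)  # [2, 1]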

shape: A tf.TensorShape, the shape to partition.
dtype: A tf.dtypes.DType indicating the type of the partition value.
axis: The axis to partition along. Default: outermost axis.

A list of integers representing the number of partitions on each axis, where the i-th value corresponds to the i-th axis.
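To make the returned list concrete, the sketch below turns a per-axis partition count into per-shard sizes. It assumes each axis is split as evenly as possible, with earlier shards absorbing any remainder; the helper `sketch_shard_sizes` is hypothetical, and the actual shard layout is decided by whatever consumes the partitioner (for example, a distribution strategy), which may differ.

```python
def sketch_shard_sizes(dim_size, num_partitions):
    """Hypothetical helper: split dim_size into num_partitions near-equal parts."""
    base, extra = divmod(dim_size, num_partitions)
    # Assumption: the first `extra` shards each take one additional entry.
    return [base + 1 if i < extra else base for i in range(num_partitions)]


shape = [10, 3]
partitions = [2, 1]  # e.g. the result of splitting axis 0 into two shards
sizes = [sketch_shard_sizes(d, p) for d, p in zip(shape, partitions)]
print(sizes)  # [[5, 5], [3]]
```

Under this assumption, a partition count of 1 on an axis leaves that axis whole, which is why every axis other than the partitioned one reports 1.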