tf.distribute.experimental.partitioners.Partitioner

Partitioner base class: all partitioners inherit from this class.

Partitioners should implement a __call__ method with the following signature:

def __call__(self, shape, dtype, axis=0):
  # Partitions the given `shape` and returns the partition results.
  # See docstring of `__call__` method for the format of partition results.
  ...
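To make the contract concrete, here is a minimal sketch of a custom partitioner. It is written in plain Python so it runs without TensorFlow: `PartitionerSketch` is a hypothetical stand-in for this base class, shapes are plain lists of ints rather than tf.TensorShape, and dtypes are strings looked up in a hypothetical `DTYPE_BYTES` table. The `ByteBudgetPartitioner` policy (cap each shard at a byte budget) is illustrative only and is not a TensorFlow class.

```python
import math
from typing import Dict, List, Sequence

# Hypothetical bytes-per-element table; real code would read tf.DType.size.
DTYPE_BYTES: Dict[str, int] = {"float32": 4, "float64": 8, "int32": 4, "int64": 8}


class PartitionerSketch:
    """Hypothetical stand-in for the Partitioner base class (no TensorFlow)."""

    def __call__(self, shape: Sequence[int], dtype: str, axis: int = 0) -> List[int]:
        raise NotImplementedError("subclasses must implement __call__")


class ByteBudgetPartitioner(PartitionerSketch):
    """Hypothetical policy: split `axis` so each shard fits in `max_shard_bytes`."""

    def __init__(self, max_shard_bytes: int):
        if max_shard_bytes < 1:
            raise ValueError("max_shard_bytes must be a positive integer")
        self._max_shard_bytes = max_shard_bytes

    def __call__(self, shape: Sequence[int], dtype: str, axis: int = 0) -> List[int]:
        # Bytes in one slice along `axis`: every other dimension is kept whole.
        slice_bytes = DTYPE_BYTES[dtype]
        for i, dim in enumerate(shape):
            if i != axis:
                slice_bytes *= dim
        # How many slices fit in the budget (always at least one per shard).
        slices_per_shard = max(1, self._max_shard_bytes // max(1, slice_bytes))
        result = [1] * len(shape)  # 1 means "not split on this axis"
        result[axis] = min(shape[axis], math.ceil(shape[axis] / slices_per_shard))
        return result


print(ByteBudgetPartitioner(48)([10, 3], "float32"))  # [3, 1]
```

For a [10, 3] float32 shape and a 48-byte budget, one row is 12 bytes, so 4 rows fit per shard and axis 0 is cut into 3 pieces. Every axis that is not partitioned stays at 1, in line with the return format described under Returns.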

Methods

__call__


Partitions the given shape and returns the partition results.

Example of a partitioner that allocates a fixed number of shards:

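import tensorflow as tf

partitioner = tf.distribute.experimental.partitioners.FixedShardsPartitioner(num_shards=2)
partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
print(partitions)  # [2, 1]: 2 shards along axis 0, axis 1 left whole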
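The rule behind this example can be modeled in a few lines of plain Python. This is a sketch, not TensorFlow's implementation: `fixed_shards` is a hypothetical helper, shapes are lists of ints, and capping the shard count at the length of the partitioned axis is an assumption about how a fixed-shard policy should handle small dimensions.

```python
from typing import List, Sequence


def fixed_shards(shape: Sequence[int], num_shards: int, axis: int = 0) -> List[int]:
    """Hypothetical helper modeling a fixed-shard partition result."""
    if num_shards < 1:
        raise ValueError("num_shards must be a positive integer")
    result = [1] * len(shape)  # every other axis stays in one piece
    # Assumption: never request more shards than there are rows on `axis`.
    result[axis] = min(num_shards, shape[axis])
    return result


print(fixed_shards([10, 3], num_shards=2))          # [2, 1]
print(fixed_shards([1, 3], num_shards=2))           # [1, 1]
print(fixed_shards([10, 3], num_shards=2, axis=1))  # [1, 2]
```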

Args
shape: A tf.TensorShape, the shape to partition.
dtype: A tf.dtypes.DType indicating the type of the partition value.
axis: The axis to partition along. Default: 0, the outermost axis.

Returns
A list of integers representing the number of partitions on each axis, where the i-th value corresponds to the i-th axis.
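A partition list only says how many pieces each axis is cut into; it does not say how large each piece is. The hypothetical `shard_shapes` helper below sketches one way to turn a result such as [2, 1] into concrete shard shapes. It assumes a single partitioned axis and an as-even-as-possible split in which earlier shards take any leftover rows; both are illustrative assumptions rather than guarantees about how TensorFlow lays out shards.

```python
from typing import List, Sequence


def shard_shapes(shape: Sequence[int], partitions: Sequence[int]) -> List[List[int]]:
    """Hypothetical helper: concrete shard shapes implied by a partition list."""
    if len(partitions) != len(shape):
        raise ValueError("partitions must have one entry per axis")
    split_axes = [i for i, n in enumerate(partitions) if n > 1]
    if not split_axes:
        return [list(shape)]  # nothing is split: a single shard
    if len(split_axes) > 1:
        raise ValueError("this sketch handles only one partitioned axis")
    axis = split_axes[0]
    num_shards = partitions[axis]
    base, extra = divmod(shape[axis], num_shards)
    shapes = []
    for i in range(num_shards):
        piece = list(shape)
        # Assumption: the first `extra` shards each take one leftover row.
        piece[axis] = base + (1 if i < extra else 0)
        shapes.append(piece)
    return shapes


print(shard_shapes([10, 3], [2, 1]))  # [[5, 3], [5, 3]]
print(shard_shapes([10, 3], [3, 1]))  # [[4, 3], [3, 3], [3, 3]]
```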