Policy that splits tensors into shards based on their device spec task.
Inherits From: ShardingCallback
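The "task" in a device spec is the `task:<n>` component of a device string such as `/job:worker/replica:0/task:1/device:CPU:0`. In TensorFlow, `tf.DeviceSpec.from_string(device).task` exposes this value. The dependency-free sketch below parses it with a regular expression instead. The helper name `device_task` is hypothetical, and returning `None` for a device string with no task component is an assumption of this sketch.

```python
import re
from typing import Optional

# Matches the "task:<n>" component of a TensorFlow-style device string.
_TASK_RE = re.compile(r"/task:(\d+)")


def device_task(device: str) -> Optional[int]:
    """Hypothetical helper: return the task index in `device`, or None if absent."""
    match = _TASK_RE.search(device)
    return int(match.group(1)) if match else None


print(device_task("/job:worker/replica:0/task:1/device:CPU:0"))  # 1
print(device_task("/device:CPU:0"))  # None
```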
Methods
__call__
__call__(
    shardable_tensors: Sequence[tf.train.experimental.ShardableTensor]
) -> Sequence[sharding_util.TensorSliceDict]
Callback to split tensors into shards based on their device spec task.
Args | |
---|---|
shardable_tensors | A list of tf.train.experimental.ShardableTensor objects to split into shards. |
Returns | |
---|---|
A list of shard dicts containing the tensors, in the form [ {checkpoint key: {slice_spec: tensor} } ]. |
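To illustrate the documented return shape, the sketch below groups tensors by the task in their device string, so each task's tensors land in one shard dict of the form {checkpoint key: {slice_spec: tensor}}. It is a simplified pure-Python model, not TensorFlow's implementation. `FakeShardableTensor` is a hypothetical stand-in for `tf.train.experimental.ShardableTensor` that carries only the fields used here. Tensors are plain strings. Ordering shards by ascending task index, with task-less devices last, is an assumption.

```python
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence

# Mirrors the documented shard structure: {checkpoint_key: {slice_spec: tensor}}.
TensorSliceDict = Dict[str, Dict[str, object]]


@dataclass(frozen=True)
class FakeShardableTensor:
    """Hypothetical stand-in for tf.train.experimental.ShardableTensor."""
    checkpoint_key: str
    slice_spec: str
    device: str
    tensor: object


def _task(device: str) -> Optional[int]:
    # Extract the "task:<n>" component of the device string, if any.
    match = re.search(r"/task:(\d+)", device)
    return int(match.group(1)) if match else None


def shard_by_task(
    shardable_tensors: Sequence[FakeShardableTensor],
) -> List[TensorSliceDict]:
    """Return one {checkpoint_key: {slice_spec: tensor}} dict per task."""
    shards: Dict[Optional[int], TensorSliceDict] = {}
    for st in shardable_tensors:
        shard = shards.setdefault(_task(st.device), {})
        shard.setdefault(st.checkpoint_key, {})[st.slice_spec] = st.tensor
    # Sort shards by task index; tensors without a task go last.
    order = sorted(shards, key=lambda t: (t is None, t if t is not None else 0))
    return [shards[t] for t in order]


tensors = [
    FakeShardableTensor("w", "", "/job:worker/task:1/device:CPU:0", "w_value"),
    FakeShardableTensor("b", "", "/job:worker/task:0/device:CPU:0", "b_value"),
]
print(shard_by_task(tensors))
# [{'b': {'': 'b_value'}}, {'w': {'': 'w_value'}}]
```

Grouping into a dict keyed by task keeps the pass linear in the number of tensors. Multiple slices of the same checkpoint key on one task end up under the same key in that task's shard.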