Policy that splits tensors into shards with a max shard size.
Inherits From: ShardingCallback
tf.train.experimental.MaxShardSizePolicy(
max_shard_size: int
)
A shard may exceed the max shard size if it contains either (1) a single scalar or string tensor that cannot be sliced and is itself larger than the max shard size, or (2) the checkpoint object graph, whose size cannot be calculated at save time.
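The size-capping rule above can be illustrated with a simplified, pure-Python greedy packer. This is a hypothetical sketch, not TensorFlow's implementation: tensors are modeled as (checkpoint_key, num_rows, bytes_per_row) tuples, sizes are assumed to be in bytes, slicing happens only along the first axis, and the checkpoint object graph (exception 2) is not modeled. Because it slices only whole rows, a single row larger than the limit also ends up in an oversized shard. That is a limitation of this sketch, not of the documented policy.

```python
from typing import Dict, List, Tuple

# Hypothetical tensor model for this sketch (not a TensorFlow type):
# (checkpoint_key, num_rows, bytes_per_row). num_rows == 0 marks a
# scalar/string tensor that cannot be sliced; its size is bytes_per_row.
Tensor = Tuple[str, int, int]
# Each shard maps checkpoint_key -> (start_row, end_row); (0, 0) means unsliced.
Shard = Dict[str, Tuple[int, int]]


def split_into_shards(tensors: List[Tensor], max_shard_size: int) -> List[Shard]:
    """Greedily pack tensors, or row ranges of them, into size-capped shards."""
    shards: List[Shard] = []
    current: Shard = {}
    used = 0

    def flush() -> None:
        # Close the current shard (if non-empty) and start a fresh one.
        nonlocal current, used
        if current:
            shards.append(current)
        current, used = {}, 0

    for key, num_rows, bytes_per_row in tensors:
        if num_rows == 0:  # unsliceable scalar/string tensor
            if bytes_per_row > max_shard_size:
                # Too big and cannot be sliced: it gets an oversized shard of its own.
                flush()
                shards.append({key: (0, 0)})
                continue
            if used + bytes_per_row > max_shard_size:
                flush()
            current[key] = (0, 0)
            used += bytes_per_row
            continue

        start = 0
        while start < num_rows:
            # Number of whole rows that still fit in the current shard.
            room = max(0, (max_shard_size - used) // bytes_per_row)
            if room == 0:
                if used > 0:
                    flush()  # current shard is full; retry in a fresh one
                    continue
                room = 1  # one row alone exceeds the limit; emit it anyway
            end = min(num_rows, start + room)
            current[key] = (start, end)
            used += (end - start) * bytes_per_row
            start = end
    flush()
    return shards


# Example: 100-byte limit, one 80-byte sliceable tensor, two unsliceable ones.
print(split_into_shards([("a", 10, 8), ("b", 0, 250), ("c", 0, 30)], 100))
# → [{'a': (0, 10)}, {'b': (0, 0)}, {'c': (0, 0)}]
```

In the example, "b" (250 bytes) cannot be sliced, so it lands alone in a shard larger than the limit, matching case (1) above. A 240-byte sliceable tensor under the same limit would instead be split into row ranges (0, 12), (12, 24) and (24, 30).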
| Attributes | |
|---|---|
| description | |
Methods
__call__
__call__(
shardable_tensors: Sequence[tf.train.experimental.ShardableTensor]
) -> Sequence[sharding_util.TensorSliceDict]
Callback to split tensors into shards with a max shard size.
| Args | |
|---|---|
| shardable_tensors | A list of ShardableTensors. |
| Returns | |
|---|---|
| List of shard dicts containing tensors, each of the form {checkpoint key: {slice_spec: tensor}}. |
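To make the nested return structure concrete, the following hypothetical sketch walks a list of {checkpoint key: {slice_spec: tensor}} dicts and totals the size of each shard. It assumes plain Python lists of float32 values (4 bytes each) in place of real tf.Tensor objects, and uses the empty string as the slice_spec of an unsliced tensor; the helper names shard_sizes and oversized_shards are illustrative, not part of the TensorFlow API.

```python
from typing import Dict, List, Sequence

# Hypothetical stand-in mirroring the documented shape
# [ {checkpoint key: {slice_spec: tensor}} ]; tensors are plain lists of
# float32 values (4 bytes each) rather than tf.Tensor objects.
ShardDict = Dict[str, Dict[str, List[float]]]


def shard_sizes(shards: Sequence[ShardDict], bytes_per_element: int = 4) -> List[int]:
    """Total byte size of every shard in a callback's return value."""
    sizes = []
    for shard in shards:
        total = 0
        for slices in shard.values():      # checkpoint key -> {slice_spec: tensor}
            for tensor in slices.values():  # slice_spec -> tensor
                total += len(tensor) * bytes_per_element
        sizes.append(total)
    return sizes


def oversized_shards(shards: Sequence[ShardDict], max_shard_size: int) -> List[int]:
    """Indices of shards whose total size exceeds max_shard_size."""
    return [i for i, size in enumerate(shard_sizes(shards)) if size > max_shard_size]


# Example: two shards; "" is used here as the slice_spec of an unsliced tensor.
shards = [
    {"dense/kernel": {"": [0.0] * 16}},                    # 64 bytes
    {"dense/bias": {"": [0.0] * 4}, "step": {"": [0.0]}},  # 20 bytes
]
print(shard_sizes(shards))           # → [64, 20]
print(oversized_shards(shards, 32))  # → [0]
```

A check like oversized_shards can help confirm that any shard over the limit falls under one of the two documented exceptions.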