TPU distribution strategy implementation.
tf.compat.v1.distribute.experimental.TPUStrategy( tpu_cluster_resolver=None, steps_per_run=None, device_assignment=None )
|tpu_cluster_resolver|A tf.distribute.cluster_resolver.TPUClusterResolver, which provides information about the TPU cluster.|
|steps_per_run|Number of steps to run on device before returning to the host. Note that this can have side effects on performance, hooks, metrics, summaries, etc. This parameter is only used when Distribution Strategy is used with Estimator or Keras.|
Returns the cluster resolver associated with this strategy.
In general, when using a multi-worker
Strategies that intend to have an associated
Single-worker strategies usually do not have a
For more information, please see
||Returns the number of replicas over which gradients are aggregated.|
||DEPRECATED: use .extended.steps_per_run instead.|
experimental_distribute_dataset( dataset, options=None )
The following is an example:
    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    # Create a dataset and batch it by the global batch size.
    # GLOBAL_BATCH_SIZE and replica_fn are assumed to be defined by the caller.
    dataset = tf.data.TFRecordDataset([
        "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"]).batch(GLOBAL_BATCH_SIZE)

    # Distribute that dataset
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    # Iterate over the `tf.distribute.DistributedDataset`
    for x in dist_dataset:
        # process dataset elements
        strategy.run(replica_fn, args=(x,))
In the code snippet above, the tf.distribute.DistributedDataset
dist_dataset is batched by GLOBAL_BATCH_SIZE, and we iterate through it
with for x in dist_dataset. Each x contains data for all replicas, which
aggregates to a batch of GLOBAL_BATCH_SIZE.
tf.distribute.Strategy.run will take care of feeding the right
per-replica data in x to the right replica_fn executed on each replica.
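
The per-replica feeding described above can be pictured with a small pure-Python model. This is only a sketch: split_global_batch is a hypothetical helper, not a TensorFlow API, and the contiguous, equal-size slicing (with an evenly divisible batch) is an assumption made for illustration rather than a statement of how TensorFlow splits batches internally.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_global_batch(global_batch: Sequence[T], num_replicas: int) -> List[List[T]]:
    """Split one global batch into equal, contiguous per-replica slices.

    Hypothetical helper for illustration only; it is not a TensorFlow API.
    """
    if num_replicas <= 0:
        raise ValueError("num_replicas must be positive")
    if len(global_batch) % num_replicas != 0:
        # Assumption: only the evenly divisible case is modeled here.
        raise ValueError("global batch size must be divisible by num_replicas")
    per_replica = len(global_batch) // num_replicas
    return [
        list(global_batch[i * per_replica:(i + 1) * per_replica])
        for i in range(num_replicas)
    ]


GLOBAL_BATCH_SIZE = 8  # illustrative value, not taken from the docs
x = list(range(GLOBAL_BATCH_SIZE))  # stands in for one element of dist_dataset
per_replica_values = split_global_batch(x, num_replicas=4)
print(per_replica_values)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

In this model, each inner list plays the role of the data one replica_fn call would receive, and the slices together add up to GLOBAL_BATCH_SIZE elements.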
What's under the hood of this method, when we say the dataset gets
distributed? It depends on the auto-shard policy you set through
tf.data.experimental.DistributeOptions. By default, it is set to
tf.data.experimental.AutoShardPolicy.AUTO. In a multi-worker setting, we
will first attempt to distribute dataset by detecting whether dataset is
being created out of reader datasets (e.g.
tf.data.TextLineDataset, etc.) and if so, try to shard the input files.
Note that there has to be at least one input file per worker. If you have
fewer input files than workers, we suggest that you disable dataset
sharding across workers by setting
tf.data.experimental.DistributeOptions.auto_shard_policy to
tf.data.experimental.AutoShardPolicy.OFF.
If the attempt to shard by file is unsuccessful (i.e. the dataset is not
read from files), we will shard the dataset evenly at the end by
appending a .shard operation to the end of the processing pipeline. This
will cause the entire preprocessing pipeline for all the data to be run on
every worker, and each worker will do redundant work. We will print a
warning if this route is selected.
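
To make the difference between the two sharding routes concrete, here is a simplified pure-Python model. It is a sketch, not the TensorFlow implementation: shard_by_file and shard_at_end are hypothetical helpers, the round-robin assignment of files to workers is an assumption, and the modulo rule mirrors the documented behavior of tf.data.Dataset.shard (keep elements whose index modulo the number of shards equals the shard index).

```python
from typing import Callable, Iterable, List, Sequence


def shard_by_file(files: Sequence[str], num_workers: int, worker_index: int) -> List[str]:
    """File-based route: each worker reads only its own subset of files."""
    if len(files) < num_workers:
        raise ValueError("need at least one input file per worker")
    # Assumption: files are assigned round-robin by position.
    return list(files[worker_index::num_workers])


def shard_at_end(
    elements: Iterable[int],
    preprocess: Callable[[int], int],
    num_workers: int,
    worker_index: int,
) -> List[int]:
    """Fallback route: run the whole pipeline, then keep every n-th element."""
    processed = [preprocess(e) for e in elements]  # full pipeline on every worker
    return [p for i, p in enumerate(processed) if i % num_workers == worker_index]


files = ["/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"]
worker_files = [shard_by_file(files, num_workers=2, worker_index=i) for i in range(2)]

calls = []  # records every preprocessing call across both workers


def preprocess(e: int) -> int:
    calls.append(e)
    return e * 10


worker0 = shard_at_end(range(6), preprocess, num_workers=2, worker_index=0)
worker1 = shard_at_end(range(6), preprocess, num_workers=2, worker_index=1)
print(worker_files, worker0, worker1, len(calls))
```

With two workers and six elements, the fallback route preprocesses twelve elements in total to deliver six, which is the redundant work the warning refers to. The file-based route avoids this because each worker only ever touches its own files.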
As mentioned before, within each worker, we will also split the data among all the worker devices (if more than one is present). This will happen even if multi-worker sharding is disabled.
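
Disabling multi-worker sharding, as discussed above, is a matter of dataset options. The following configuration fragment is a minimal sketch, assuming a recent TensorFlow 2.x release in which tf.data.Options exposes experimental_distribute.auto_shard_policy. The file paths are placeholders borrowed from the earlier example.

```python
import tensorflow as tf

# Placeholder file list; the files are not read until the dataset is iterated.
dataset = tf.data.TFRecordDataset(["/a/1.tfr", "/a/2.tfr"])

# Turn off automatic sharding across workers for this dataset.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF
)
dataset = dataset.with_options(options)
```

With the policy set to OFF, every worker reads the full dataset, so this is mainly appropriate when there are fewer input files than workers or when each worker is meant to see all the data.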
If the above batch splitting and dataset sharding logic is undesirable,