Synchronous training on TPUs and TPU Pods.

Inherits From: Strategy

To construct a TPUStrategy object, you need to run the initialization code as below:

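resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)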

When using distribution strategies, variables created within the strategy's scope are replicated across all replicas and can be kept in sync using all-reduce algorithms.
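A minimal sketch of creating variables under a strategy's scope. The helper name create_weights and the weight shapes are hypothetical. The helper accepts any tf.distribute.Strategy, so it can be exercised on a CPU with tf.distribute.OneDeviceStrategy as a stand-in and then called unchanged with the TPUStrategy constructed above.

```python
import tensorflow as tf


def create_weights(strategy: tf.distribute.Strategy):
    """Hypothetical helper: create toy model weights under strategy's scope."""
    # Variables created inside strategy.scope() are created by the strategy;
    # under TPUStrategy they are replicated to every replica.
    with strategy.scope():
        w = tf.Variable(tf.zeros([8, 4]), name="w")
        b = tf.Variable(tf.zeros([4]), name="b")
    return w, b
```

Creating the variables outside the scope would leave them unmanaged by the strategy, which is why model and optimizer construction normally happens inside the with block.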

To run TF2 programs on TPUs, you can either use the .compile and .fit APIs in tf.keras with TPUStrategy, or write your own customized training loop by calling strategy.run directly. Note that TPUStrategy doesn't support pure eager execution, so make sure the function passed into strategy.run is a tf.function, or that strategy.run is called inside a tf.function if eager behavior is enabled.
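A minimal custom-training-loop sketch, assuming a toy linear model y = x @ w, plain SGD, and the hypothetical names make_train_step, GLOBAL_BATCH_SIZE and LEARNING_RATE. strategy.run is called inside a tf.function, as TPUStrategy requires; per-replica gradients and losses are summed with strategy.reduce and the update is applied in cross-replica context. This is one simple way to structure the loop, not the only one (Keras optimizers can aggregate gradients inside the replica function instead).

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 8  # assumed value for illustration
LEARNING_RATE = 0.1  # assumed value for illustration


def make_train_step(strategy: tf.distribute.Strategy, w: tf.Variable):
    """Hypothetical helper: one SGD step for y = x @ w.

    `w` must have been created under strategy.scope().
    """

    def replica_fn(x, y):
        # Runs once per replica on that replica's slice of the global batch.
        with tf.GradientTape() as tape:
            per_example_loss = tf.reduce_sum(tf.square(y - tf.matmul(x, w)), axis=1)
            # Divide by the global batch size so that summing the per-replica
            # losses (and gradients) yields the mean over the whole batch.
            loss = tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
        return tape.gradient(loss, w), loss

    @tf.function  # TPUStrategy does not support calling strategy.run eagerly.
    def train_step(batch):
        x, y = batch
        grads, losses = strategy.run(replica_fn, args=(x, y))
        # Sum the per-replica results across all replicas.
        grad = strategy.reduce(tf.distribute.ReduceOp.SUM, grads, axis=None)
        loss = strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)
        # Cross-replica update: applied to every copy of w.
        w.assign_sub(LEARNING_RATE * grad)
        return loss

    return train_step
```

On a TPU, the batches would come from an iterator over strategy.experimental_distribute_dataset(...) or strategy.distribute_datasets_from_function(...), with the dataset batched to GLOBAL_BATCH_SIZE (or to the per-replica size, respectively).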

Args:
tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, which provides information about the TPU cluster.
device_assignment: Optional tf.tpu.experimental.DeviceAssignment to specify the placement of replicas on the TPU cluster.

Attributes:
cluster_resolver: Returns the cluster resolver associated with this strategy.

tf.distribute.experimental.TPUStrategy provides the associated tf.distribute.cluster_resolver.ClusterResolver. If the user provides one in __init__, that instance is returned; if the user does not, a default tf.distribute.cluster_resolver.TPUClusterResolver is provided.

extended: tf.distribute.StrategyExtended with additional methods.
num_replicas_in_sync: Returns the number of replicas over which gradients are aggregated.
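A common use of num_replicas_in_sync is scaling a per-replica batch size up to the global batch size consumed by one synchronous step. A minimal sketch; the helper global_batch_size and the constant PER_REPLICA_BATCH_SIZE are hypothetical:

```python
import tensorflow as tf

PER_REPLICA_BATCH_SIZE = 16  # assumed value for illustration


def global_batch_size(strategy: tf.distribute.Strategy) -> int:
    """Hypothetical helper: total examples consumed per synchronous step."""
    # Each of the num_replicas_in_sync replicas processes one per-replica
    # batch, and gradients are aggregated over all of them.
    return PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
```

With 8 replicas in sync, this would give 128; that global value is what a loss-scaling call such as tf.nn.compute_average_loss expects as global_batch_size.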

Distributes tf.data.Dataset instances created by calls to dataset_fn.

The argument dataset_fn that users pass in is an input function that takes a tf.distribute.InputContext argument and returns a tf.data.Dataset instance. The dataset returned from dataset_fn is expected to be already batched by the per-replica batch size (i.e. the global batch size divided by the number of replicas in sync) and sharded; tf.distribute.Strategy.distribute_datasets_from_function does not batch or shard the tf.data.Dataset instance returned from the input function. dataset_fn is called on the CPU device of each worker, and each call generates a dataset from which every replica on that worker dequeues one batch of inputs per step (i.e. if a worker has two replicas, two batches are dequeued from the dataset every step).
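A minimal dataset_fn sketch following this contract: it batches by the per-replica size obtained from the InputContext and shards by input pipeline. The constant GLOBAL_BATCH_SIZE and the placeholder tf.data.Dataset.range source are assumptions for illustration; real input would typically come from files.

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 64  # assumed value for illustration


def dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
    # Per-replica batch size = global batch size / num_replicas_in_sync.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    dataset = tf.data.Dataset.range(1024)  # placeholder data source
    # Give each input pipeline (one per worker) a disjoint shard.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
```

It would then be passed to the strategy as strategy.distribute_datasets_from_function(dataset_fn).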

This method lets you specify your own batching and sharding logic. (In contrast, tf.distribute.Strategy.experimental_distribute_dataset does batching and sharding for you.) For example, where experimental_distribute_dataset is unable to shard the input files, this method can be used to shard the dataset manually, avoiding the slow fallback behavior in experimental_distribute_dataset. If the dataset is infinite, the sharding can instead be done by creating dataset replicas that differ only in their random seed.
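For the infinite-dataset case, a minimal sketch of seed-based "sharding": each input pipeline builds the same random stream but with a seed offset by its input_pipeline_id. The names random_dataset_fn, GLOBAL_BATCH_SIZE and BASE_SEED are hypothetical, and tf.data.Dataset.random (available in recent TF 2.x releases) stands in for a real synthetic or augmented input stream.

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 64  # assumed value for illustration
BASE_SEED = 1234  # assumed value for illustration


def random_dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    # No shard() call: pipelines differ only in their seed, so each worker
    # draws a different (but reproducible) infinite stream.
    seed = BASE_SEED + input_context.input_pipeline_id
    return tf.data.Dataset.random(seed=seed).batch(batch_size)
```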

The dataset_fn should take a tf.distribute.InputContext instance, through which information about batching and input replication can be accessed.

You can use the element_spec property of the tf.distribute.DistributedDataset returned by this API to query the tf.TypeSpec of the elements returned by the iterator. This can be used to set the input_signature property of a tf.function. See tf.distribute.DistributedDataset.element_spec for an example.
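A minimal sketch of passing element_spec as a tf.function input_signature, so the step is traced once for the structure the distributed iterator yields. The helper name make_step is hypothetical, and the per-replica "work" (a sum) is a placeholder; tf.distribute.OneDeviceStrategy is used only as a CPU stand-in for testing.

```python
import tensorflow as tf


def make_step(strategy: tf.distribute.Strategy, dist_dataset):
    """Hypothetical helper: a step with a fixed signature for dist_dataset."""

    # element_spec describes the (per-replica) structure the iterator yields,
    # so the tf.function gets a fixed signature instead of retracing.
    @tf.function(input_signature=[dist_dataset.element_spec])
    def step(batch):
        per_replica_sum = strategy.run(tf.reduce_sum, args=(batch,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_sum, axis=None)

    return step
```

Using drop_remainder=True when batching keeps the batch dimension static in element_spec, which TPUs generally need.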