Synchronous training on TPUs and TPU Pods.
Inherits From: Strategy
tf.distribute.TPUStrategy(
    tpu_cluster_resolver=None, experimental_device_assignment=None
)
To construct a TPUStrategy object, you need to run the initialization code as below:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
While using distribution strategies, the variables created within the strategy's scope will be replicated across all the replicas and can be kept in sync using all-reduce algorithms.
To run TF2 programs on TPUs, you can either use .compile and .fit APIs in tf.keras with TPUStrategy, or write your own customized training loop by calling strategy.run directly. Note that TPUStrategy doesn't support pure eager execution, so please make sure the function passed into strategy.run is a tf.function, or that strategy.run is called inside a tf.function if eager behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu.
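For example, the following hedged sketch wraps strategy.run in a tf.function so it can be called even when eager execution is enabled (the trivial computation is only for illustration):

@tf.function
def run_on_tpu():
  # strategy.run is invoked inside a tf.function, as required.
  return strategy.run(lambda: tf.constant(1.0))

run_on_tpu()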
experimental_distribute_datasets_from_function and experimental_distribute_dataset APIs can be used to distribute the dataset across the TPU workers when writing your own training loop. If you are using fit and compile methods available in tf.keras.Model, then Keras will handle the distribution for you.
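As a brief illustration of that path, here is a minimal, hedged sketch of compiling and fitting a small Keras model under TPUStrategy (the model, loss, optimizer, and synthetic data below are illustrative assumptions, not prescribed by the API):

import numpy as np
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# Variables (the model and its optimizer state) must be created in scope.
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(2, input_shape=(5,)),
  ])
  model.compile(
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# Synthetic data purely for illustration; Keras distributes it across replicas.
x = np.random.random((16, 5)).astype(np.float32)
y = np.random.randint(2, size=(16, 1))
model.fit(x, y, epochs=2, batch_size=8)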
An example of writing a customized training loop on TPUs:
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(2, input_shape=(5,)),
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

def dataset_fn(ctx):
  x = np.random.random((2, 5)).astype(np.float32)
  y = np.random.randint(2, size=(2, 1))
  dataset = tf.data.Dataset.from_tensor_slices((x, y))
  return dataset.repeat().batch(1, drop_remainder=True)

dist_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn)
iterator = iter(dist_dataset)

@tf.function()
def train_step(iterator):
  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

  strategy.run(step_fn, args=(next(iterator),))

train_step(iterator)
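In a real loop you would call train_step repeatedly; each call pulls one per-replica batch from the iterator. A minimal hedged continuation of the example above (the step count is arbitrary):

# Run several more steps; each call consumes one per-replica batch.
for _ in range(10):
  train_step(iterator)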
For advanced use cases like model parallelism, you can set the experimental_device_assignment argument when creating TPUStrategy to specify the number of replicas and the number of logical devices. Below is an example to initialize the TPU system with 2 logical devices and 1 replica.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=[1, 1, 1, 2],
    num_replicas=1)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)
Then you can run a tf.add operation only on logical device 0.
@tf.function()
def step_fn(inputs):
  features, _ = inputs
  output = tf.add(features, features)

  # Add operation will be executed on logical device 0.
  output = strategy.experimental_assign_to_logical_device(output, 0)
  return output

dist_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn)
iterator = iter(dist_dataset)
strategy.run(step_fn, args=(next(iterator),))
Args | |
---|---|
tpu_cluster_resolver | A tf.distribute.cluster_resolver.TPUClusterResolver, which provides information about the TPU cluster. If None, it will assume running on a local TPU worker.
experimental_device_assignment | Optional tf.tpu.experimental.DeviceAssignment to specify the placement of replicas on the TPU cluster.
Attributes | |
---|---|
cluster_resolver | Returns the cluster resolver associated with this strategy. In general, multi-worker strategies that intend to have an associated tf.distribute.cluster_resolver.ClusterResolver return it through this property; single-worker strategies usually do not have one, in which case this property returns None. For more information, please see tf.distribute.cluster_resolver.ClusterResolver.
extended | tf.distribute.StrategyExtended with additional methods.
num_replicas_in_sync | Returns number of replicas over which gradients are aggregated.
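These attributes are commonly used to derive a global batch size and to inspect the resolved TPU cluster; a short hedged sketch (the per-replica batch size below is an assumed value, not part of the API):

per_replica_batch_size = 8  # assumed value, for illustration only
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync

# The TPUClusterResolver passed at construction is exposed here.
print(strategy.cluster_resolver.cluster_spec())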
Methods
experimental_assign_to_logical_device
experimental_assign_to_logical_device(
    tensor, logical_device_id
)
Adds annotation that tensor will be assigned to a logical device.
# Initializing TPU system with 2 logical devices and 4 replicas.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=[1, 1, 1, 2],
    num_replicas=4)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)

# `inputs` is assumed to be a distributed dataset created elsewhere.
iterator = iter(inputs)

@tf.function()
def step_fn(inputs):
  output = tf.add(inputs, inputs)

  # Add operation will be executed on logical device 0.
  output = strategy.experimental_assign_to_logical_device(output, 0)
  return output

strategy.run(step_fn, args=(next(iterator),))
Args | |
---|---|
tensor | Input tensor to annotate.
logical_device_id | Id of the logical core to which the tensor will be assigned.

Raises | |
---|---|
ValueError | The logical device id presented is not consistent with the total number of partitions specified by the device assignment.

Returns | |
---|---|
Annotated tensor with identical value as tensor.
experimental_distribute_dataset
experimental_distribute_dataset(
    dataset, options=None
)
Creates tf.distribute.DistributedDataset from tf.data.Dataset.

The returned tf.distribute.DistributedDataset can be iterated over similar to how regular datasets can.

NOTE: The user cannot add any more transformations to a tf.distribute.DistributedDataset.

The following is an example:
strategy = tf.distribute.MirroredStrategy()

# Create a dataset
dataset = tf.data.TFRecordDataset([
    "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"])

# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Iterate over the `tf.distribute.DistributedDataset`
for x in dist_dataset:
  # process dataset elements
  strategy.run(replica_fn, args=(x,))
In the code snippet above, the tf.distribute.DistributedDataset dist_dataset is batched by GLOBAL_BATCH_SIZE, and we iterate through it using for x in dist_dataset. x is a tf.distribute.DistributedValues containing data for all replicas, which aggregates to a batch of GLOBAL_BATCH_SIZE.
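Per-replica results returned by strategy.run over such a DistributedValues input can be combined with strategy.reduce; a minimal hedged sketch, assuming replica_fn returns a scalar per-replica loss:

@tf.function
def distributed_step(x):
  per_replica_losses = strategy.run(replica_fn, args=(x,))
  # Sum the per-replica losses into a single scalar on the host.
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

for x in dist_dataset:
  loss = distributed_step(x)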