tf.distribute.MirroredStrategy

Synchronous training across multiple replicas on one machine.

Inherits From: Strategy

This strategy is typically used for training on one machine with multiple GPUs. For TPUs, use tf.distribute.TPUStrategy. To use MirroredStrategy with multiple workers, please refer to tf.distribute.experimental.MultiWorkerMirroredStrategy.

A variable created under a MirroredStrategy is a MirroredVariable: one copy of the variable is created on each replica, and the copies are kept in sync. If no devices are specified in the strategy's constructor argument, all available GPUs are used. If no GPUs are found, the available CPUs are used. Note that TensorFlow treats all CPUs on a machine as a single device and uses threads internally for parallelism.

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  x = tf.Variable(1.)
x
MirroredVariable:{
  0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
  1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
}
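
If the device list is omitted, the constructor mirrors across every GPU TensorFlow can see and falls back to the CPU when none are found. A minimal sketch of that default (num_replicas_in_sync reports how many replicas the strategy created):

strategy = tf.distribute.MirroredStrategy()  # No devices given: use all visible GPUs, or the CPU.
print("Number of replicas in sync:", strategy.num_replicas_in_sync)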

When using distribution strategies, all variable creation should be done within the strategy's scope. This replicates the variables across all the replicas and keeps them in sync using an all-reduce algorithm.
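
To see the replication explicitly, experimental_local_results returns the per-replica copies of a value created in scope. A short sketch, again assuming two visible GPUs as in the snippet above:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(1.)
# Returns a tuple with one tf.Variable per replica.
print(strategy.experimental_local_results(v))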

Variables created inside a tf.function are still MirroredVariables, as long as the function is called within the strategy's scope.

x = []
@tf.function  # Wrap the function with tf.function.
def create_variable():
  if not x:
    x.append(tf.Variable(1.))
  return x[0]
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  _ = create_variable()
  print(x[0])
MirroredVariable:{
  0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
  1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
}

When writing your own training loop, use experimental_distribute_dataset to distribute the dataset across the replicas. If you are using the .fit and .compile methods available in tf.keras, tf.keras handles the distribution for you.

For example:
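
A minimal, self-contained sketch of such a loop (the toy dataset, replica_fn, and distributed_epoch are illustrative names introduced here; strategy.run, strategy.reduce, and experimental_distribute_dataset are the Strategy APIs being demonstrated):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # Uses all visible GPUs, or the CPU if none.

def replica_fn(batch):
  # Toy per-replica step: each replica sums its shard of the global batch.
  return tf.reduce_sum(batch)

@tf.function
def distributed_epoch(dist_dataset):
  total = tf.constant(0.)
  for batch in dist_dataset:
    per_replica = strategy.run(replica_fn, args=(batch,))
    # Combine the per-replica results with an all-reduce (SUM).
    total += strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
  return total

dataset = tf.data.Dataset.range(8).map(lambda x: tf.cast(x, tf.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
print(distributed_epoch(dist_dataset))  # Sum of 0..7 -> 28.0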