tf.keras.layers.experimental.SyncBatchNormalization

Normalize and scale inputs or activations synchronously across replicas.

Inherits From: Layer, Module

Applies batch normalization to the activations of the previous layer at each batch by synchronizing the global batch statistics across all devices that are training the model. For specific details about batch normalization, please refer to the tf.keras.layers.BatchNormalization layer docs.

When this layer is used with a tf.distribute strategy to train models across devices/workers, there will be an allreduce call to aggregate batch statistics across all replicas at every training step. Without a tf.distribute strategy, this layer behaves as a regular tf.keras.layers.BatchNormalization layer.
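Conceptually, the aggregation amounts to computing per-replica moments and combining them with an all-reduce. The helper below is a minimal sketch of that idea, not the layer's actual implementation; the name _cross_replica_moments is made up for illustration, and it assumes it runs inside a replica context with equally sized batches on every replica.

import tensorflow as tf

def _cross_replica_moments(x, axes):
  # Hypothetical helper: combine per-replica batch statistics with an
  # all-reduce, assuming equal batch sizes on every replica.
  ctx = tf.distribute.get_replica_context()
  n = tf.cast(ctx.num_replicas_in_sync, x.dtype)
  local_mean = tf.reduce_mean(x, axes, keepdims=True)
  local_sq_mean = tf.reduce_mean(tf.square(x), axes, keepdims=True)
  # Sum the per-replica moments across replicas, then divide by replica count.
  global_mean = ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_mean) / n
  global_sq_mean = ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sq_mean) / n
  # Var[x] = E[x^2] - (E[x])^2, computed from the global moments.
  global_var = global_sq_mean - tf.square(global_mean)
  return global_mean, global_var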

Example usage:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # Layers created inside the strategy scope are replicated across devices;
  # SyncBatchNormalization then aggregates batch statistics across replicas.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(16))
  model.add(tf.keras.layers.experimental.SyncBatchNormalization())
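The model can then be compiled and trained as usual; the strategy handles the per-step allreduce transparently. The snippet below is a hedged sketch only: the optimizer, loss, and random data shapes are illustrative and not part of the API.

import numpy as np

with strategy.scope():
  # Compiling inside the scope keeps optimizer variables mirrored as well.
  model.compile(optimizer='sgd', loss='mse')

# Illustrative random data; shapes are made up for this sketch.
x = np.random.random((64, 8)).astype('float32')
y = np.random.random((64, 16)).astype('float32')
model.fit(x, y, batch_size=32, epochs=1)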

Args:
  axis: Integer, the axis that should be normalized (typically the features axis). For instance, after a Conv2D layer with data_format="channels_first", set axis=1 in BatchNormalization.
  momentum: Momentum for the moving average.
  epsilon: Small float added to variance to avoid dividing by zero.
  center: If True, add offset of beta to normalized tensor. If False, beta is ignored.
  scale: If True, multiply by gamma. If False, gamma is not used. When the next layer is linear (this also applies to, e.g., tf.nn.relu), this can be disabled since the scaling will be done by the next layer.
  beta_initializer: Initializer for the beta weight.
  gamma_initializer: Initializer for the gamma weight.