tf.keras.mixed_precision.experimental.Policy

A dtype policy for a Keras layer.

A dtype policy determines dtype-related aspects of a layer, such as its computation and variable dtypes. Each layer has a policy. Policies can be passed to the dtype argument of layer constructors, or a global policy can be set with tf.keras.mixed_precision.experimental.set_policy. A layer will default to the global policy if no policy is passed to its constructor.

For many models, each layer's policy will have the same compute dtype and variable dtype, which will typically be float32. In this case, we refer to the singular dtype as the layer's dtype, which can be queried by the property tf.keras.layers.Layer.dtype.

When mixed precision training is used, most layers will instead have a float16 or bfloat16 compute dtype and a float32 variable dtype, and so the layer does not have a single dtype. When the variable dtype does not match the compute dtype, variables will be automatically cast to the compute dtype to avoid type errors. In this case, tf.keras.layers.Layer.dtype refers to the variable dtype, not the compute dtype. See the mixed precision guide for more information on how to use mixed precision.

Certain policies also have a tf.mixed_precision.experimental.LossScale instance, which is used by tf.keras.Models to perform loss scaling. Loss scaling is a technique used with mixed precision to avoid numerical underflow in float16 gradients. Loss scaling is only done by Models in Model.fit, Model.train_on_batch, and similar methods. Layers which are not Models ignore the loss scale.
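
To make the underflow problem concrete, the following is a minimal sketch that uses only the Python standard library rather than TensorFlow. The helper to_float16, the gradient value 1e-8, and the scale 2 ** 15 are illustrative assumptions chosen for this example; they are not taken from the Policy API.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest float16 value (struct format 'e')."""
    return struct.unpack('e', struct.pack('e', x))[0]


# Illustrative values (assumptions, not TensorFlow defaults taken from this page).
LOSS_SCALE = 2.0 ** 15
tiny_grad = 1e-8

# Stored directly in float16, the gradient is too small to represent.
unscaled = to_float16(tiny_grad)

# Multiplying the loss by LOSS_SCALE multiplies its gradient by the same
# factor, which moves the value into float16's representable range.
scaled = to_float16(tiny_grad * LOSS_SCALE)

# Dividing by LOSS_SCALE in higher precision recovers the gradient.
recovered = scaled / LOSS_SCALE

print(unscaled)   # 0.0: the gradient underflowed
print(recovered)  # close to 1e-8, up to float16 rounding error
```

The recovered value is only accurate to float16 precision, which is the usual trade-off of computing gradients in float16.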

Policies are constructed by passing a string to the constructor, e.g. tf.keras.mixed_precision.experimental.Policy('float32'). The string determines the compute and variable dtypes. It can be one of the following:

  • Any dtype name, such as 'float32' or 'float64'. Both the variable and compute dtypes will be that dtype. No loss scaling is done by default.
  • 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or bfloat16, while the variable dtype is float32. These policies are used for mixed precision training. With 'mixed_float16', a dynamic loss scale is used by default. 'mixed_bfloat16' does no loss scaling by default, as loss scaling is unnecessary with bfloat16.
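
The two naming rules above can be summarized in a small pure-Python sketch. PolicySummary and summarize_policy_name are hypothetical names introduced for illustration and are not part of TensorFlow; the helper also does not check that a plain name is a valid dtype, which the real constructor would be expected to do.

```python
from typing import NamedTuple, Optional


class PolicySummary(NamedTuple):
    """Hypothetical record of what a policy name implies."""
    compute_dtype: str
    variable_dtype: str
    default_loss_scale: Optional[str]  # 'dynamic' or None


def summarize_policy_name(name: str) -> PolicySummary:
    """Map a policy name to its dtypes, following the rules listed above."""
    if name == 'mixed_float16':
        # float16 compute, float32 variables, dynamic loss scale by default.
        return PolicySummary('float16', 'float32', 'dynamic')
    if name == 'mixed_bfloat16':
        # bfloat16 compute, float32 variables, no loss scaling by default.
        return PolicySummary('bfloat16', 'float32', None)
    # Any plain dtype name: both dtypes are that dtype, no loss scaling.
    return PolicySummary(name, name, None)


for policy_name in ('float32', 'float64', 'mixed_float16', 'mixed_bfloat16'):
    print(policy_name, summarize_policy_name(policy_name))
```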

How to use mixed precision in a Keras model

To use mixed precision in a Keras model, the 'mixed_float16' or 'mixed_bfloat16' policy can be used. tf.keras.mixed_precision.experimental.set_policy can be used to set the default policy for layers if no policy is passed to them. For example:

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
model = tf.keras.models.Sequential([
    tf.keras.layers.Input((100,)),
    # Dense layers use global policy of 'mixed_float16', which does
    # computations in float16 while keeping variables in float32.
    tf.keras.layers.Dense(10),
    tf.keras.layers.Dense(10),
    # Softmax should be done in float32 for numeric stability. We pass
    # dtype='float32' to use float32 instead of the global policy.
    tf.keras.layers.Activation('softmax', dtype='float32')
])

Alternatively, the policy can be passed to individual layers instead of setting the global policy with set_policy:

policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
model = tf.keras.models.Sequential([
    tf.keras.layers.Input((100,)),
    tf.keras.layers.Dense(10, dtype=policy),
    tf.keras.layers.Dense(10, dtype=policy),
    # Softmax should be done in float32 for numeric stability.
    tf.keras.layers.Activation('softmax', dtype='float32')
])

Note that the 'mixed_float16' policy applies loss scaling by default in Model.fit, Model.train_on_batch, and other training methods. If 'mixed_float16' is used without such a method (e.g., in a custom training loop), the loss scale must be applied manually. See tf.keras.mixed_precision.experimental.LossScaleOptimizer for details. For 'mixed_bfloat16', no loss scaling is done and loss scaling never needs to be manually applied.
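
The order of operations when applying a loss scale by hand can be sketched without TensorFlow. The example below is a toy model in plain Python, not the LossScaleOptimizer implementation: manual_step and to_float16 are hypothetical helpers, Python floats stand in for higher-precision variables, and skipping the update when the scaled gradient overflows is an assumption about what a custom loop might reasonably do.

```python
import math
import struct


def to_float16(x: float) -> float:
    """Round x to the nearest float16 value; overflow becomes +/-inf."""
    try:
        return struct.unpack('e', struct.pack('e', x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def manual_step(weight, grad, loss_scale, learning_rate=1.0):
    """One toy update: scale, compute in float16, unscale, then apply."""
    # Scaling the loss scales every gradient by the same factor, so the
    # float16 backward pass effectively sees grad * loss_scale.
    scaled_grad = to_float16(grad * loss_scale)
    # Unscale in higher precision before touching the variable.
    unscaled_grad = scaled_grad / loss_scale
    # Assumption: skip the update if the scaled gradient overflowed.
    if not math.isfinite(unscaled_grad):
        return weight
    return weight - learning_rate * unscaled_grad


print(manual_step(0.5, 1e-8, loss_scale=1.0))       # gradient underflows, weight unchanged
print(manual_step(0.5, 1e-8, loss_scale=2.0 ** 15)) # gradient survives, weight decreases
print(manual_step(0.5, 10.0, loss_scale=2.0 ** 15)) # scaled gradient overflows, step skipped
```

The third call illustrates why a loss scale that is too large is also a problem: the scaled gradient no longer fits in float16, so the step has to be discarded.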

See the mixed precision guide for more information on using mixed precision.

How to use float64 in a Keras model

Using float64 is similar to mixed precision. Either the global policy can be set to float64, or dtype='float64' can be passed to individual layers. For example, to set the global policy:

tf.keras.mixed_precision.experimental.set_policy('float64')
model = tf.keras.models.Sequential([
    tf.keras.layers.Input((100,)),
    # All layers use global policy of 'float64', which does computations
    # and creates variables in float64.
    tf.keras.layers.Dense(10),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
])
# Optionally set policy back to float32 if any other models use float32
tf.keras.mixed_precision.experimental.set_policy('float32')

How a layer uses its policy's compute dtype

A layer will cast its inputs to its compute dtype in TensorFlow 2. For example:

x = tf.ones((4, 4, 4, 4), dtype='float64')
# `layer`'s policy defaults to float32.
layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
# `layer` casts its inputs to its compute dtype, which is float32, and
# does computations in float32.
y = layer(x)
y.dtype  # tf.float32

Note that the base tf.keras.layers.Layer class inserts the casts. If subclassing your own layer, you do not have to insert any casts.

Currently, only tensors in the first argument to the layer's call method are cast. For example:

class MyLayer(tf.keras.layers.Layer):
  # Bug! `b` will not be cast.
  def call(self, a, b):
    return a + 1., b + 1.

a = tf.constant(1., dtype="float32")
b = tf.constant(1., dtype="float32")
layer = MyLayer(dtype="float64")
x, y = layer(a, b)
x.dtype  # tf.float64
y.dtype  # tf.float32

If writing your own layer, it is recommended to accept tensors only in the first argument. This way, all tensors are cast to the layer's compute dtype. MyLayer should therefore be written as:

class MyLayer(tf.keras.layers.Layer):
  # Now, all tensor inputs will be cast.
  def call(self, inputs):
    a, b = inputs
    return a + 1., b + 1.

a = tf.constant(1., dtype="float32")
b = tf.constant(1., dtype="float32")
layer = MyLayer(dtype="float64")
x, y = layer((a, b))
x.dtype  # tf.float64
y.dtype  # tf.float64

Other arguments are not automatically cast for technical reasons, but this may change in a future minor release.

The casting only occurs in TensorFlow 2. If tf.compat.v1.disable_v2_behavior() has been called, the casting can still be enabled by calling tf.compat.v1.keras.layers.enable_v2_dtype_behavior().

A layer subclass can prevent its inputs from being autocast by passing autocast=False to the layer constructor. For example:

class NonAutoCastingLayer(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    kwargs['autocast'] = False
    super(NonAutoCas