Enable mixed precision via a graph rewrite. (deprecated)
tf.train.experimental.enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic')
Mixed precision is the use of both float32 and float16 data types when training a model to improve performance. This is achieved via a graph rewrite operation and a loss-scale optimizer.
Performing arithmetic operations in float16 takes advantage of specialized processing units, such as NVIDIA Tensor Cores, for much higher arithmetic throughput. However, due to the smaller representable range, performing the entire training with float16 can result in gradient underflow, that is, small gradient values becoming zeroes. Instead, performing only select arithmetic operations in float16 results in higher throughput and decreased training time when using compatible hardware accelerators while also reducing memory usage, typically without sacrificing model accuracy.
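The range problem can be made concrete without a GPU. The sketch below rounds Python floats to IEEE 754 half precision using the standard-library struct format code 'e'. This is an assumption of the illustration: it only emulates float16 storage and does not use Tensor Cores. Values much smaller than 2**-24 (about 6e-8, the smallest positive float16 number) round to zero, and values above 65504 (the largest finite float16) cannot be represented.

```python
import math
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses finite values beyond the float16 range; IEEE float16
        # arithmetic would produce an infinity instead, so emulate that.
        return math.copysign(math.inf, x)


# A gradient this small is far below 2**-24, so it silently becomes zero:
# this is gradient underflow.
print(to_float16(1e-8))      # 0.0
# Moderately small values survive, with reduced precision.
print(to_float16(1e-4))
# The largest finite float16 value is exact; anything much larger overflows.
print(to_float16(65504.0))   # 65504.0
print(to_float16(70000.0))   # inf
```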
import numpy as np
import tensorflow as tf
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(64, activation='softmax'),
])
opt = tf.keras.optimizers.SGD()
# Enable the graph rewrite and wrap the optimizer to use loss scaling.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
model.compile(loss="mse", optimizer=opt)
x_train = np.random.random((1024, 64))
y_train = np.random.random((1024, 64))
model.fit(x_train, y_train)
Calling enable_mixed_precision_graph_rewrite(opt) enables the graph rewrite operation before computing gradients. The function additionally returns the optimizer (opt) wrapped with a LossScaleOptimizer, which prevents underflow in the float16 tensors during the backward pass. An optimizer, such as the Keras optimizer in the example above or a tf.compat.v1.train.Optimizer, must be passed to this function; it is then wrapped to use loss scaling.
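The following is a minimal, framework-free sketch of what a dynamically loss-scaled training step does, to illustrate the wrapping described above. It is not TensorFlow's LossScaleOptimizer: the class and function names are hypothetical, and the constants (an initial scale of 2**15, doubling after 2000 consecutive steps with finite gradients, halving whenever a gradient is Inf or NaN) are assumptions modeled on common dynamic loss-scaling defaults.

```python
import math
from typing import Callable, List


class DynamicLossScaleSketch:
    """Hypothetical dynamic loss scaling (illustration only, not a TF API)."""

    def __init__(self, initial_scale: float = 2.0**15,
                 increment_period: int = 2000, multiplier: float = 2.0):
        self.scale = initial_scale
        self.increment_period = increment_period
        self.multiplier = multiplier
        self.good_steps = 0  # consecutive steps with finite gradients

    def step(self, params: List[float],
             grad_fn: Callable[[float], List[float]], lr: float) -> bool:
        """Apply one SGD update; return False if the step was skipped.

        grad_fn(scale) must return the gradients of (scale * loss).
        """
        grads = [g / self.scale for g in grad_fn(self.scale)]  # unscale
        if not all(math.isfinite(g) for g in grads):
            # Inf/NaN gradients: skip the update and shrink the scale.
            self.scale /= self.multiplier
            self.good_steps = 0
            return False
        for i, g in enumerate(grads):
            params[i] -= lr * g
        self.good_steps += 1
        if self.good_steps >= self.increment_period:
            # A long run of finite gradients: try a larger scale.
            self.scale *= self.multiplier
            self.good_steps = 0
        return True


def toy_grad_fn(scale: float) -> List[float]:
    # Hypothetical model whose true gradient is 0.5; the scaled gradient is
    # treated as overflowing (inf) whenever the scale exceeds 2**14.
    return [0.5 * scale if scale <= 2.0**14 else math.inf]


scaler = DynamicLossScaleSketch(increment_period=3)
weights = [1.0]
history = [scaler.step(weights, toy_grad_fn, lr=0.1) for _ in range(5)]
print(history)       # [False, True, True, True, False]
print(scaler.scale)  # 16384.0
```

Skipping the update on overflow, rather than clipping, keeps a single bad batch from corrupting the weights, while the periodic increase lets the scale climb back once gradients are stable.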
The graph rewrite operation changes the dtype of certain operations in the
graph from float32 to float16. There are several categories of operations
that are either included or excluded by this rewrite operation. The following
categories of Ops are defined inside corresponding functions under the class
ClearList: Ops that do not have numerically significant adverse effects. E.g.
AllowList: Ops that are considered numerically safe for execution in float16, and thus are always converted. E.g.
DenyList: Ops that are numerically unsafe to execute in float16 and can negatively affect downstream nodes. E.g.
GrayList: Ops that are considered numerically safe for execution in float16 unless downstream from a DenyList Op. E.g.
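The category rules listed above can be illustrated with a deliberately simplified dtype-assignment pass over a small graph. This is a hypothetical sketch, not TensorFlow's implementation: the real rewrite uses a more involved algorithm and also has to insert casts where float16 and float32 regions meet. Here, graph inputs are assumed to stay float32, a ClearList op is assumed to run in float16 only when one of its inputs already does, and the category given to each example op is an assumption chosen for the demo.

```python
from enum import Enum
from typing import Dict, List


class Category(Enum):
    INPUT = "input"  # hypothetical: graph inputs stay float32
    ALLOW = "allow"  # always converted to float16
    DENY = "deny"    # never converted; kept in float32
    GRAY = "gray"    # float16 unless downstream of a DENY op
    CLEAR = "clear"  # follows its inputs


def assign_dtypes(categories: Dict[str, Category],
                  inputs: Dict[str, List[str]]) -> Dict[str, str]:
    """Simplified dtype assignment; `categories` must be in topological order."""
    dtypes: Dict[str, str] = {}
    tainted: Dict[str, bool] = {}  # downstream of a DENY op via GRAY/CLEAR ops
    for name, cat in categories.items():
        srcs = inputs.get(name, [])
        after_deny = any(categories[s] is Category.DENY or tainted[s] for s in srcs)
        # DENY influence flows through GRAY and CLEAR ops, but not ALLOW ops.
        tainted[name] = cat in (Category.GRAY, Category.CLEAR) and after_deny
        if cat is Category.ALLOW:
            dtypes[name] = "float16"
        elif cat in (Category.DENY, Category.INPUT):
            dtypes[name] = "float32"
        elif cat is Category.GRAY:
            dtypes[name] = "float32" if after_deny else "float16"
        else:  # CLEAR: float16 only if an input already is and no DENY upstream
            any_half = any(dtypes[s] == "float16" for s in srcs)
            dtypes[name] = "float16" if any_half and not after_deny else "float32"
    return dtypes


# Illustrative graph: x -> matmul -> bias_add -> relu -> softmax -> scale.
graph = {
    "x": Category.INPUT,
    "matmul": Category.ALLOW,
    "bias_add": Category.GRAY,
    "relu": Category.CLEAR,
    "softmax": Category.DENY,
    "scale": Category.GRAY,
}
edges = {
    "matmul": ["x"],
    "bias_add": ["matmul"],
    "relu": ["bias_add"],
    "softmax": ["relu"],
    "scale": ["softmax"],
}
print(assign_dtypes(graph, edges))
```

In this run the GrayList op "bias_add" is converted because nothing upstream is on the DenyList, while the GrayList op "scale" stays in float32 because it consumes the output of the DenyList op "softmax".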
When this function is used, gradients should be computed and applied with the returned optimizer, by calling opt.compute_gradients() followed by opt.apply_gradients(). If gradients are instead computed with tf.GradientTape, loss scaling will not be applied, which will likely cause your model not to converge due to float16 underflow. To apply loss scaling with tf.GradientTape, see keras.mixed_precision.experimental.LossScaleOptimizer for details on how to do this.
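Why the scaling has to wrap the gradient computation can be sketched without TensorFlow. Below, a hypothetical toy loss, loss = C * w**2 / 2 with C = 1e-8, is differentiated by hand, and every intermediate gradient is rounded to float16 with the standard-library struct 'e' format (an assumption standing in for real float16 kernels). Multiplying the loss by a scale before differentiating and dividing the gradient by the same scale afterwards is, in spirit, what LossScaleOptimizer does; the helper names here are illustrative only.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Hypothetical toy model: loss = C * h with h = w**2 / 2, so d(loss)/dw = C * w.
C = 1e-8
w = 2.0


def backward_float16(loss_scale: float) -> float:
    """Backpropagate loss_scale * loss, rounding each gradient to float16."""
    grad_loss = to_float16(loss_scale)  # d(loss_scale * loss) / d(loss)
    grad_h = to_float16(grad_loss * C)  # chain rule through loss = C * h
    grad_w = to_float16(grad_h * w)     # chain rule through h = w**2 / 2
    return grad_w


true_grad = C * w
unscaled_run = backward_float16(1.0)          # no loss scaling
scale = 2.0**15
scaled_run = backward_float16(scale) / scale  # unscale in full precision

print(true_grad)                                       # 2e-08
print(unscaled_run)                                    # 0.0
print(abs(scaled_run - true_grad) / true_grad < 1e-3)  # True
```

Without scaling, the intermediate gradient 1e-8 underflows and the weight never moves; with a scale of 2**15 the same backward pass stays in the representable float16 range, and the unscaled result is within about 0.03% of the true gradient.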
For NVIDIA GPUs with Tensor Cores, as a general performance guide, dimensions (such as batch size, input size, output size, and channel counts) should be powers of two if under 256, or divisible by 8 if above 256. For more information, check out the NVIDIA Deep Learning Performance Guide.
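The rule of thumb above can be written as a small check. These are illustrative helpers, not part of TensorFlow; suggest_dim is a hypothetical name, and padding a dimension up to the next conforming value is only one possible response, trading a little extra computation for a shape that follows the guide.

```python
def is_power_of_two(n: int) -> bool:
    """True for 1, 2, 4, 8, ...; uses the single-set-bit trick."""
    return n > 0 and (n & (n - 1)) == 0


def follows_tensor_core_guide(dim: int) -> bool:
    """Rule of thumb from the text: a power of two below 256, else a multiple of 8.

    256 itself is both a power of two and a multiple of 8, so it passes.
    """
    if dim < 256:
        return is_power_of_two(dim)
    return dim % 8 == 0


def suggest_dim(dim: int) -> int:
    """Hypothetical helper: smallest value >= dim that follows the guide."""
    candidate = max(dim, 1)
    while not follows_tensor_core_guide(candidate):
        candidate += 1
    return candidate


# 64 and 1024 follow the guide; 100 and 1001 do not.
print([follows_tensor_core_guide(d) for d in (64, 1024, 100, 1001)])  # [True, True, False, False]
print(suggest_dim(100), suggest_dim(1001))  # 128 1008
```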
Currently, mixed precision is only enabled on NVIDIA Tensor Core GPUs with Compute Capability 7.0 and above (Volta, Turing, or newer architectures). The parts of the graph on CPUs and TPUs are untouched by the graph rewrite.
Comparison with the Keras mixed precision API
Both this function and the Keras mixed precision API enable the use of mixed precision in a model, so only one of the two APIs can be used at a time. We recommend using the Keras mixed precision API, as it is more customizable and supports Eager execution. However, it only supports models which use Keras layers, while the graph rewrite works in any model that uses
The core difference between the two APIs is that this function is a graph rewrite, and so it changes the graph to use mixed precision under the hood. You still build your graph in float32, and the graph rewrite will change certain ops to float16. The Keras mixed precision API directly builds the Keras Model using a mix of float16 and float32.
One core advantage of the Keras API is that it supports mixed precision with Eager execution, i.e. mixed precision outside tf.functions. The graph rewrite only affects ops within tf.functions, which makes issues with mixed precision harder to debug. The Keras API is also more customizable, as you can override any layer to run in float32 by passing dtype="float32" to the layer constructor. Additionally, you can query the dtype of tensors in the model by checking tensor.dtype. With the graph rewrite, all tensors appear to be float32, since the dtype is only changed under the hood.
The main advantage of the graph rewrite (this function) is that it works even if you do not use Keras layers or any other part of Keras. The Keras mixed precision API requires models which use Keras layers, as it only inserts casts inside Keras layers and models. Another advantage is that the graph rewrite never results in a TypeError, which the Keras API may introduce if you perform certain operations outside Keras. For example, the following results in a TypeError if the Keras mixed precision API is enabled, because a float16 tensor and a float32 tensor are added:
tf.keras.layers.Dense(2)(x) + tf.keras.layers.Dense(2, dtype="float32")(x)
Args:
opt: An instance of a
loss_scale: Either an int/float, the string
Returns:
A version of