Sets the global dtype policy.
```python
tf.keras.mixed_precision.set_global_policy(
    policy
)
```
The global policy is the default tf.keras.mixed_precision.Policy used for
layers, if no policy is passed to the layer constructor.
```python
>>> import tensorflow as tf
>>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
>>> tf.keras.mixed_precision.global_policy()
<Policy "mixed_float16">
>>> tf.keras.layers.Dense(10).dtype_policy
<Policy "mixed_float16">
>>> # The global policy is not used if a policy is passed directly to the constructor
>>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
<Policy "float64">
>>> # Restore the default float32 global policy
>>> tf.keras.mixed_precision.set_global_policy('float32')
```
If no global policy is set, layers will instead default to a Policy
constructed from tf.keras.backend.floatx().
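The lookup order described above can be modeled in a few lines of plain Python. This is a toy sketch, not the Keras implementation: `resolve_dtype_policy` is a hypothetical helper, and policies are represented as plain strings rather than `tf.keras.mixed_precision.Policy` objects.

```python
from typing import Optional


def resolve_dtype_policy(
    layer_dtype: Optional[str],
    global_policy: Optional[str],
    floatx: str = "float32",
) -> str:
    """Toy model of how a layer picks its dtype policy (hypothetical helper).

    Precedence, as described in the text:
    1. a policy passed directly to the layer constructor,
    2. otherwise the global policy, if one has been set,
    3. otherwise a policy built from the backend floatx.
    """
    if layer_dtype is not None:
        return layer_dtype
    if global_policy is not None:
        return global_policy
    return floatx


# Global policy set, no per-layer override: the global policy wins.
print(resolve_dtype_policy(None, "mixed_float16"))  # mixed_float16
# A per-layer dtype always overrides the global policy.
print(resolve_dtype_policy("float64", "mixed_float16"))  # float64
# No global policy set: fall back to floatx.
print(resolve_dtype_policy(None, None, floatx="float64"))  # float64
```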
To use mixed precision, set the global policy to 'mixed_float16' or
'mixed_bfloat16', so that every layer uses a 16-bit compute dtype and a
float32 variable dtype by default.
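The split between compute and variable dtypes for the two mixed policies can be sketched as a simple lookup. This is an illustrative model only; `policy_dtypes` and `_MIXED_POLICIES` are hypothetical names, not part of the Keras API.

```python
from typing import Tuple

# Hypothetical lookup table for the two mixed policies named in the text.
_MIXED_POLICIES = {
    "mixed_float16": ("float16", "float32"),
    "mixed_bfloat16": ("bfloat16", "float32"),
}


def policy_dtypes(name: str) -> Tuple[str, str]:
    """Return (compute_dtype, variable_dtype) for a policy name (toy model).

    Mixed policies compute in 16 bits but keep variables in float32;
    any other name, such as 'float32', uses the same dtype for both.
    """
    return _MIXED_POLICIES.get(name, (name, name))


for name in ("mixed_float16", "mixed_bfloat16", "float32"):
    compute, variable = policy_dtypes(name)
    print(f"{name}: compute={compute}, variable={variable}")
```

On a real `tf.keras.mixed_precision.Policy`, the same pair is exposed through the `compute_dtype` and `variable_dtype` properties.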
Only floating-point policies, such as 'float32' and 'mixed_float16', can be
set as the global policy. Non-floating-point policies such as 'int32' and
'complex64' cannot be set as the global policy because most layers do not
support them.
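The restriction above can be illustrated with a toy validation function. This is a minimal sketch under stated assumptions: `check_global_policy` and `_FLOAT_DTYPES` are hypothetical, the check here looks at the policy's compute dtype, and neither the exact error message nor the way Keras decides whether a dtype is floating point is reproduced.

```python
# Hypothetical set of floating-point dtype names accepted by this toy check.
_FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}


def check_global_policy(name: str) -> str:
    """Toy validation: reject global policies whose compute dtype is not floating point."""
    # Mixed policies compute in 16 bits; other names compute in their own dtype.
    compute_dtype = {"mixed_float16": "float16", "mixed_bfloat16": "bfloat16"}.get(name, name)
    if compute_dtype not in _FLOAT_DTYPES:
        raise ValueError(
            f"Only floating-point policies can be set as the global policy, got {name!r}"
        )
    return name


for name in ("float32", "mixed_float16", "int32", "complex64"):
    try:
        check_global_policy(name)
        print(f"{name}: accepted")
    except ValueError as err:
        print(f"{name}: rejected ({err})")
```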
See tf.keras.mixed_precision.Policy for more information.
| Argument | Description |
|---|---|
| `policy` | A `Policy`, or a string that will be converted to a `Policy`. Can also be `None`, in which case the global policy will be constructed from `tf.keras.backend.floatx()`. |